In [1]:
import os

# Move from the notebook's directory up to the project root once per kernel;
# re-running this cell must not keep climbing the directory tree.
if "PROJECT_ROOT" not in globals():
    os.chdir('../')
    PROJECT_ROOT = os.getcwd()

In [2]:
%pwd

'/home/milad/projects/medical-nlp-pipeline'

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the "prepare base model" stage.

    Instances are built by ``configurationManager.get_prepare_base_model_config``
    from the ``prepare_base_model`` config section and the params file.
    """

    # Directory for this stage's artifacts.
    root_dir: Path
    # Destination for the pretrained encoder saved by PrepareBaseModel.get_base_model.
    nlp_base_model_path: Path
    # Destination for the classifier saved by PrepareBaseModel.update_base_model.
    nlp_updated_base_model_path: Path
    # Number of outputs of the classification head (params CLASSES).
    params_classes: int
    # Params PRETRAINED flag; NOTE(review): not read by any code visible in this notebook.
    params_pretrained: bool

In [4]:
from medical_nlp.constants import *
from medical_nlp.utils.common import read_yaml, create_directories

In [5]:
class configurationManager:
    """Reads the project's YAML files and builds per-stage config objects.

    Args:
        config_file_path: Main config YAML; defaults to ``CONFIG_FILE_PATH``.
        params_file_path: Hyper-parameter YAML; defaults to ``PARAMS_FILE_PATH``.

    On construction, ``create_directories`` is called for ``artifacts_root``.
    """

    def __init__(self, config_file_path=CONFIG_FILE_PATH,
                 params_file_path=PARAMS_FILE_PATH):
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)
        # Every stage writes below artifacts_root, so set it up first.
        create_directories([self.config.artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Return the settings for the "prepare base model" stage.

        Paths come from the ``prepare_base_model`` config section; the class
        count and pretrained flag come from the params file. Calls
        ``create_directories`` for the stage's ``root_dir`` before returning.
        """
        stage = self.config.prepare_base_model
        create_directories([stage.root_dir])
        return PrepareBaseModelConfig(
            root_dir=stage.root_dir,
            nlp_base_model_path=stage.nlp_base_model_path,
            nlp_updated_base_model_path=stage.nlp_updated_base_model_path,
            params_classes=self.params.CLASSES,
            params_pretrained=self.params.PRETRAINED,
        )

In [3]:
import os
import torch
from torchsummary import summary
from medical_nlp import logger
import transformers
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
class BERT(nn.Module):
    """Sequence classifier: a pretrained BERT encoder plus a linear head on its pooled output.

    Args:
        num_classes: Number of logits produced by the head.
        pretrained_model_name: Checkpoint passed to
            ``transformers.BertModel.from_pretrained``; defaults to
            ``"bert-base-uncased"``.
    """

    def __init__(self, num_classes, pretrained_model_name="bert-base-uncased"):
        super().__init__()
        # Attribute names `bert_model` and `out` are kept unchanged: they fix the
        # state_dict keys and are accessed by PrepareBaseModel._prepare_full_model.
        self.bert_model = transformers.BertModel.from_pretrained(pretrained_model_name)
        # The pooled output width is the encoder's hidden size (768 for bert-base).
        self.out = nn.Linear(self.bert_model.config.hidden_size, num_classes)

    def forward(self, ids, mask=None, token_type_ids=None):
        """Return unnormalised class logits.

        Args:
            ids: Token ids, presumably ``(batch, seq_len)`` -- TODO confirm with the data pipeline.
            mask: Optional attention mask, same shape as ``ids``.
            token_type_ids: Optional segment ids, same shape as ``ids``.

        Returns:
            Output of the linear head applied to BERT's pooled output,
            presumably ``(batch, num_classes)``.
        """
        # Read the pooled output by name instead of unpacking a positional tuple,
        # whose length grows if hidden states / attentions are ever requested.
        encoded = self.bert_model(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        return self.out(encoded.pooler_output)
    

In [9]:
# NOTE(review): `bert_model` is not referenced by any later visible cell, and
# BERT.__init__ loads its own copy of the same checkpoint, so this cell only adds
# an extra model load. Consider deleting it.
bert_model = transformers.BertModel.from_pretrained("bert-base-uncased")



In [6]:
# Build the model in this cell so it runs on a fresh kernel. 3 classes matches the
# 2,307-parameter head in the saved output (768 * 3 + 3). The trailing `;` hides
# the returned summary object's repr, so the table is printed only once.
model = BERT(num_classes=3)
summary(model);

Layer (type:depth-idx)                   Param #
├─BertModel: 1-1                         --
|    └─BertEmbeddings: 2-1               --
|    |    └─Embedding: 3-1               23,440,896
|    |    └─Embedding: 3-2               393,216
|    |    └─Embedding: 3-3               1,536
|    |    └─LayerNorm: 3-4               1,536
|    |    └─Dropout: 3-5                 --
|    └─BertEncoder: 2-2                  --
|    |    └─ModuleList: 3-6              85,054,464
|    └─BertPooler: 2-3                   --
|    |    └─Linear: 3-7                  590,592
|    |    └─Tanh: 3-8                    --
├─Linear: 1-2                            2,307
Total params: 109,484,547
Trainable params: 109,484,547
Non-trainable params: 0


Layer (type:depth-idx)                   Param #
├─BertModel: 1-1                         --
|    └─BertEmbeddings: 2-1               --
|    |    └─Embedding: 3-1               23,440,896
|    |    └─Embedding: 3-2               393,216
|    |    └─Embedding: 3-3               1,536
|    |    └─LayerNorm: 3-4               1,536
|    |    └─Dropout: 3-5                 --
|    └─BertEncoder: 2-2                  --
|    |    └─ModuleList: 3-6              85,054,464
|    └─BertPooler: 2-3                   --
|    |    └─Linear: 3-7                  590,592
|    |    └─Tanh: 3-8                    --
├─Linear: 1-2                            2,307
Total params: 109,484,547
Trainable params: 109,484,547
Non-trainable params: 0

In [None]:
class PrepareBaseModel():
    """Builds the BERT classifier, freezes its encoder, and saves both models.

    Args:
        config: Stage settings (output paths and number of classes).
    """

    def __init__(self, config: "PrepareBaseModelConfig"):
        self.config = config
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def get_base_model(self):
        """Create a ``BERT`` classifier and save its pretrained encoder.

        The encoder saved to ``config.nlp_base_model_path`` is the one wrapped
        by the classifier, so the pretrained weights are loaded only once.

        Returns:
            The ``BERT`` classifier, moved to ``self.device``.
        """
        nlp_model = BERT(self.config.params_classes)
        # Save before moving to the device so the serialised encoder stays on CPU.
        self.save_model(checkpoint=nlp_model.bert_model, path=self.config.nlp_base_model_path)
        nlp_model.to(self.device)
        return nlp_model

    @staticmethod
    def _prepare_full_model(model, freeze_till, freeze_all=False):
        """Freeze some or all of ``model.bert_model``'s parameters in place.

        Args:
            model: Module with a ``bert_model`` submodule (e.g. ``BERT``). Other
                submodules, such as the classification head, are not touched.
            freeze_till: If a positive int, freeze every encoder parameter
                except the last ``freeze_till`` (in ``parameters()`` order).
                ``None`` or ``0`` leaves the encoder trainable.
            freeze_all: If True, freeze the whole encoder. Takes precedence
                over ``freeze_till``.

        Returns:
            The same ``model`` object.
        """
        # parameters() is a generator; materialise it so it can be sliced.
        encoder_params = list(model.bert_model.parameters())
        if freeze_all:
            to_freeze = encoder_params
        elif freeze_till is not None and freeze_till > 0:
            to_freeze = encoder_params[:-freeze_till]
        else:
            to_freeze = []
        for param in to_freeze:
            param.requires_grad = False
        return model

    def update_base_model(self):
        """Build the classifier with a fully frozen encoder, print its summary, and save it.

        Sets ``self.full_model``. The model is saved to
        ``config.nlp_updated_base_model_path``.
        """
        self.full_model = self._prepare_full_model(
            model=self.get_base_model(),
            freeze_all=True,
            freeze_till=None,
        )

        self.full_model.to(self.device)
        # BERT consumes integer token ids and the config defines no input size, so
        # no dummy input is passed; the summary lists layers and parameter counts.
        summary(self.full_model)
        self.save_model(checkpoint=self.full_model, path=self.config.nlp_updated_base_model_path)
        logger.info(f"saved updated model to {self.config.nlp_updated_base_model_path}")

    @staticmethod
    def save_model(checkpoint: "nn.Module | dict", path: Path):
        """Serialise ``checkpoint`` to ``path`` with ``torch.save``.

        NOTE(review): saving a whole ``nn.Module`` pickles a reference to its
        class (here defined in the notebook), so loading it elsewhere requires
        that class to be importable. Consider saving ``state_dict()`` instead.
        """
        torch.save(checkpoint, path)
    