In [3]:
!pip install pytorch-pretrained-bert




In [4]:
!pip install tqdm



In [5]:
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor
from PIL import Image
import torch.nn.functional as F
from transformers import BertConfig, BertModel, BertPreTrainedModel, AutoConfig


import torch
import torch.nn as nn
from transformers import BertConfig, BertModel
import functools
import json
import os
from collections import Counter

import torch
import torchvision.transforms as transforms
from pytorch_pretrained_bert import BertTokenizer, BertAdam
import json
import numpy as np
import os
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader

import contextlib

import random
import torch.optim as optim
import tqdm
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score

# DATA

In [6]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(0)

In [7]:
class Args:
    def __init__(self):
        self.seed = 123
        self.batch_sz = 8
        self.max_epochs = 20
        self.task_type = "multilabel"
        self.n_workers = 4
        self.patience = 20
        
        output_path = 'output'
        self.savedir = "/kaggle/working/"
        self.save_name = 'mimic_par'
        
        self.loaddir = 'path/to/pre-trained_model'
        self.name = "scenario_name"
        
        self.openi = False
        self.data_path = '/kaggle/input/adip-hcmus-mimiccxr/'
        self.Train_dset_name = 'data/csv/train.jsonl'
        self.Valid_dset_name = 'data/csv/valid.jsonl'
        
        self.embed_sz = 768
        self.hidden_sz = 768
        self.bert_model = "bert-base-uncased"
        self.init_model = "bert-base-uncased"
        
        self.drop_img_percent = 0.0
        self.dropout = 0.1
        
        self.freeze_img = 0
        self.freeze_txt = 0
        
        self.freeze_img_all = False
        self.freeze_txt_all = False
        
        self.glove_path = "/path/to/glove_embeds/glove.840B.300d.txt"
        self.gradient_accumulation_steps = 2
        self.hidden = []
        
        self.img_embed_pool_type = "avg"
        self.img_hidden_sz = 2048
        self.include_bn = True
        
        self.lr = 1e-3
        self.lr_factor = 0.75
        self.lr_patience = 5
        
        self.max_seq_len = 512
        self.num_image_embeds = 256
        
        self.warmup = 0.1
        self.weight_classes = 1
    
args = Args()

In [8]:
class Vocab(object):
    def __init__(self, emptyInit=False):
        if emptyInit:
            self.stoi, self.itos, self.vocab_sz = {}, [], 0
        else:
            self.stoi = {
                w: i
                for i, w in enumerate(["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"])
            }
            self.itos = [w for w in self.stoi]
            self.vocab_sz = len(self.itos)

    def add(self, words):
        cnt = len(self.itos)
        for w in words:
            if w in self.stoi:
                continue
            self.stoi[w] = cnt
            self.itos.append(w)
            cnt += 1
        self.vocab_sz = len(self.itos)

In [9]:
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward"""
    if seed is None:
        yield       # Không làm gì và kết thúc ngay lập tức
        return
    if len(addl_seeds) > 0:
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)

In [10]:
class JsonlDataset(Dataset):
    def __init__(self, data_path, tokenizer, transforms, vocab, args, is_train=True):
        if is_train:
            with open(data_path + args.Train_dset_name, "r") as file:
                self.data = json.load(file)
        else:
            with open(data_path + args.Valid_dset_name, "r") as file:
                self.data = json.load(file)
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.args = args
        self.vocab = vocab
        self.n_classes = len(args.labels)
        self.text_start_token = ["[SEP]"]

        with numpy_seed(0):
            for row in self.data:
                if np.random.random() < args.drop_img_percent:
                    row["img"] = None

        self.max_seq_len = args.max_seq_len
        self.max_seq_len -= args.num_image_embeds

        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sentence = (
            self.text_start_token
            + self.tokenizer(self.data[index]["text"])[: (self.max_seq_len - 1)]
            + self.text_start_token
        )
        segment = torch.zeros(len(sentence))
        sentence = torch.LongTensor(
            [
                self.vocab.stoi[w] if w in self.vocab.stoi else self.vocab.stoi["[UNK]"]
                for w in sentence
            ]
        )
        if self.args.task_type == "multilabel":
            label = torch.zeros(self.n_classes)
            if self.data[index]["label"] == '':
                self.data[index]["label"] = "'Others'"
            else:
                pass
            label[  # Vector 14 chiều
                [self.args.labels.index(tgt) for tgt in self.data[index]["label"]]
            ] = 1
        else:
            pass

        image = None
        if self.data[index]["img"]:
            image = Image.open(
                os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
        else:
            image = Image.fromarray(128 * np.ones((256, 256, 3), dtype=np.uint8))
        image = self.transforms(image)

        # The first SEP is part of Image Token.
        segment = segment[1:]
        sentence = sentence[1:]
        # The first segment (0) is of images.
        segment += 1

        return sentence, segment, image, label

In [11]:
def get_transforms(args):
    if args.openi:
        return transforms.Compose(
            [
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
    else:
        return transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]
    )

In [12]:
def get_labels_and_frequencies(path):
    # Khởi tạo một Counter để đếm tần suất xuất hiện của các nhãn
    label_freqs = Counter()
    with open(path, "r") as file:
        data = json.load(file)
    data_labels = [line["label"] for line in data]
    if type(data_labels) == list:
        for label_row in data_labels:
            if label_row == '':
                label_row = ["'Others'"]

            label_freqs.update(label_row)
    else:
        pass
    return list(label_freqs.keys()), label_freqs

In [13]:
def get_vocab(args):
    vocab = Vocab()
    bert_tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=True
    )
    vocab.stoi = bert_tokenizer.vocab
    vocab.itos = bert_tokenizer.ids_to_tokens
    vocab.vocab_sz = len(vocab.itos)

    return vocab

In [14]:
def collate_fn(batch, args):
    lens = [len(row[0]) for row in batch]
    bsz, max_seq_len = len(batch), max(lens)

    mask_tensor = torch.zeros(bsz, max_seq_len).long()
    text_tensor = torch.zeros(bsz, max_seq_len).long()
    segment_tensor = torch.zeros(bsz, max_seq_len).long()

    img_tensor = None
    img_tensor = torch.stack([row[2] for row in batch])

    if args.task_type == "multilabel":
        # Multilabel case
        tgt_tensor = torch.stack([row[3] for row in batch])
    else:
        # Single Label case
        tgt_tensor = torch.cat([row[3] for row in batch]).long()

    for i_batch, (input_row, length) in enumerate(zip(batch, lens)):
        tokens, segment = input_row[:2]
        text_tensor[i_batch, :length] = tokens
        segment_tensor[i_batch, :length] = segment
        mask_tensor[i_batch, :length] = 1

    return text_tensor, segment_tensor, mask_tensor, img_tensor, tgt_tensor

In [15]:
def get_data_loaders(args):
    tokenizer = (
        BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True).tokenize)

    transforms = get_transforms(args)

    args.labels, args.label_freqs = get_labels_and_frequencies(
        os.path.join(args.data_path, args.Train_dset_name)
    )

    vocab = get_vocab(args)
    args.vocab = vocab
    args.vocab_sz = vocab.vocab_sz
    args.n_classes = len(args.labels)

    train_dataset = JsonlDataset(
        os.path.join(args.data_path),
        tokenizer,
        transforms,
        vocab,
        args,
    )

    args.train_data_len = len(train_dataset)

    val_dataset = JsonlDataset(
        os.path.join(args.data_path),
        tokenizer,
        transforms,
        vocab,
        args,
        is_train=False,
    )

    collate = functools.partial(collate_fn, args=args)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_sz,
        shuffle=False,
        num_workers=args.n_workers,
        collate_fn=collate,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_sz,
        shuffle=False,
        num_workers=args.n_workers,
        collate_fn=collate,
    )

    return train_dataset, train_loader, val_dataset, val_loader  # , test


In [16]:
train_dataset, train_dataloader, val_dataset, val_dataloader = get_data_loaders(args)

# MODEL

In [17]:
# args = {
#     "img_hidden_sz" : 2048,
#     "hidden_sz" : 768,
#     "dropout" : 0.1 ,
#     "num_image_embeds" : 256,
#     "init_model" : "bert-base-uncased",
#     "n_classes" : 14,
#     "img_embed_pool_type" : "avg" ,
#     "vocab" : []
# }

In [71]:
class ImageEncoder(nn.Module):
    def __init__(self, args):
        super(ImageEncoder, self).__init__()
        self.args = args
        model = torchvision.models.resnet50(pretrained=True)
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)

        pool_func = (
            nn.AdaptiveAvgPool2d
            if args.img_embed_pool_type == "avg"
            else nn.AdaptiveMaxPool2d
        )


    def forward(self, x):
        # Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048

        # out = self.pool(self.model(x))
        # out = torch.flatten(out, start_dim=2)
        # out = out.transpose(1, 2).contiguous()
        
        out = self.model(x)
        out = torch.flatten(out, start_dim=2) #out torch.Size([100, 2048, 3])
        out = out.transpose(1, 2).contiguous() #out torch.Size([100, 3, 2048])

        # print("out.size()",out.size())
        # input("STOP!!!")
        

        return out  # BxNx2048



In [72]:
class ImageBertEmbeddings(nn.Module):
    def __init__(self, args, embeddings):
        super(ImageBertEmbeddings, self).__init__()
        self.args = args
        print(args)
        self.img_embeddings = nn.Linear(args.img_hidden_sz, args.hidden_sz)
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=args.dropout)
        
    def forward(self, input_imgs, token_type_ids):
        bsz = input_imgs.size(0)
        seq_length = self.args.num_image_embeds + 2  # +2 for CLS and SEP Token

        cls_id = torch.LongTensor([self.args.vocab.stoi["[CLS]"]]).cuda()
        cls_id = cls_id.unsqueeze(0).expand(bsz, 1)
        cls_token_embeds = self.word_embeddings(cls_id)

        sep_id = torch.LongTensor([self.args.vocab.stoi["[SEP]"]]).cuda()
        sep_id = sep_id.unsqueeze(0).expand(bsz, 1)
        sep_token_embeds = self.word_embeddings(sep_id)
        
        #print(input_imgs.shape)
        imgs_embeddings = self.img_embeddings(input_imgs)
        print(imgs_embeddings)
        token_embeddings = torch.cat(
            [cls_token_embeds, imgs_embeddings, sep_token_embeds], dim=1)
        #print(imgs_embeddings.shape)
        
        position_ids = torch.arange(seq_length, dtype=torch.long).cuda()
        position_ids = position_ids.unsqueeze(0).expand(bsz, seq_length)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        #print(args.hidden_sz)
        embeddings = token_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


In [73]:
class MultimodalBertEncoder(BertPreTrainedModel):
    def __init__(self, model_config, args, configs):
        super().__init__(model_config)
        self.args = args
        self.configs = configs
        bert = BertModel(model_config)
        
        self.txt_embeddings = bert.embeddings
        self.img_embeddings = ImageBertEmbeddings(args, self.txt_embeddings)
        self.img_encoder = ImageEncoder(args)
        self.encoder = bert.encoder
        self.pooler = bert.pooler
        #self.clf = nn.Linear(args.hidden_sz, args.n_classes)

    def forward(self, input_txt, attention_mask, segment, input_img):
        bsz = input_txt.size(0)
        attention_mask = torch.cat(
            [
                torch.ones(bsz, self.args.num_image_embeds + 2).long().cuda(),
                attention_mask,
            ],
            dim=1)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        try:
            extended_attention_mask = extended_attention_mask.to(
                dtype=next(self.parameters()).dtype)  # fp16 compatibility
        except StopIteration:
            extended_attention_mask = extended_attention_mask.to(dtype=torch.float16)

        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        img_tok = (
            torch.LongTensor(input_txt.size(0), self.args.num_image_embeds + 2)
            .fill_(0)
            .cuda())
        img = self.img_encoder(input_img)  # BxNx3x224x224 -> BxNx2048

        
        img_embed_out = self.img_embeddings(img, img_tok)
        txt_embed_out = self.txt_embeddings(input_txt, segment)
        encoder_input = torch.cat([img_embed_out, txt_embed_out], 1)  # Bx(TEXT+IMG)xHID
        encoded_layers = self.encoder(encoder_input, extended_attention_mask)
        return self.pooler(encoded_layers[-1])

In [74]:
class MultimodalBertClf(BertPreTrainedModel):
    def __init__(self, model_config, args, configs):
        super().__init__(model_config)
        self.enc = MultimodalBertEncoder(model_config, args, configs)
        self.fc = nn.Linear(args.hidden_sz, args.n_classes)

    def forward(self, txt, mask, segment, img):
        x = self.enc(txt, mask, segment, img)
        return self.fc(x)
    

In [75]:
# Kiểm tra xem có GPU có sẵn hay không
if torch.cuda.is_available():
    device = torch.device('cuda')  # Sử dụng GPU
else:
    device = torch.device('cpu')  # Sử dụng CPU


In [31]:
# Load the model state dict
model_state_dict = torch.load("/kaggle/input/medvill-weight/pytorch_model.bin")

model_config = AutoConfig.from_pretrained("/kaggle/input/medvill-weight")  
model = MultimodalBertClf.from_pretrained("/kaggle/input/medvill-weight", state_dict=model_state_dict,
                     args=args, configs=model_config, local_files_only=True).to(device)

# Get the state dict of the model after it's been created
model_dict = model.state_dict()

# Compare the keys of the two state dicts
pretrained_keys = set(model_state_dict.keys())
model_keys = set(model_dict.keys())

# Find the keys that are in the model state dict but not in the pretrained state dict
newly_initialized_keys = model_keys - pretrained_keys

# Print the newly initialized keys
print(newly_initialized_keys)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


<__main__.Args object at 0x7842e6ada110>


100%|██████████| 97.8M/97.8M [00:00<00:00, 132MB/s] 
Some weights of MultimodalBertClf were not initialized from the model checkpoint at /kaggle/input/medvill-weight and are newly initialized: ['fc.bias', 'fc.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'fc.bias', 'fc.weight', 'enc.img_embeddings.word_embeddings.weight'}


# TRAINING

In [32]:
# Loss function
criterion =  nn.BCEWithLogitsLoss()

total_steps = (
        args.train_data_len
        / args.batch_sz
        / args.gradient_accumulation_steps
        * args.max_epochs)

# Get optimizer
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0, }]
optimizer = BertAdam(
    optimizer_grouped_parameters,
    lr=args.lr,
    warmup=args.warmup,
    t_total=total_steps)

# Get scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=args.lr_patience, verbose=True, factor=args.lr_factor)



In [33]:
for param in model.enc.img_encoder.parameters():
    param.requires_grad = args.freeze_img_all
for param in model.enc.encoder.parameters():
    param.requires_grad = args.freeze_txt_all

In [34]:
def model_forward(model, args, criterion, batch, device):
    txt, segment, mask, img, tgt = batch
    txt, img = txt.to(device), img.to(device)
    mask, segment = mask.to(device), segment.to(device)
    out = model(txt, mask, segment, img)

    tgt = tgt.to(device)
    loss = criterion(out, tgt)
    return loss, out, tgt

In [35]:
def model_eval(data, model, args, criterion, device, store_preds=False):
    with torch.no_grad():
        losses, preds, preds_bool, tgts, outAUROC = [], [], [], [], []
        for batch in data:
            loss, out, tgt = model_forward(model, args, criterion, batch, device)
            losses.append(loss.item())
            
            pred_bool = torch.sigmoid(out).cpu().detach().numpy() > 0.5
            pred = torch.sigmoid(out).cpu().detach().numpy()
            preds.append(pred)
            preds_bool.append(pred_bool)
            tgt = tgt.cpu().detach().numpy()
            tgts.append(tgt)

    metrics = {"loss": np.mean(losses)}
    classACC = dict()
    
    tgts = np.vstack(tgts)
    preds = np.vstack(preds)
    preds_bool = np.vstack(preds_bool)

    for i in range(args.n_classes):
        try:
            outAUROC.append(roc_auc_score(tgts[:, i], preds[:, i]))
        except ValueError:
            outAUROC.append(0)
            pass
    for i in range(0, len(outAUROC)):
        assert args.n_classes == len(outAUROC)
        classACC[args.labels[i]] = outAUROC[i]

    metrics["micro_roc_auc"] = roc_auc_score(tgts, preds, average="micro")
    metrics["macro_roc_auc"] = roc_auc_score(tgts, preds, average="macro")
    metrics["macro_f1"] = f1_score(tgts, preds_bool, average="macro")
    metrics["micro_f1"] = f1_score(tgts, preds_bool, average="micro")
    print('micro_auc:', metrics["micro_roc_auc"])
    print('micro_f1:', metrics["micro_f1"])
    print('-----------------------------------------------------')
   
    return metrics, classACC, tgts, preds

In [39]:
start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf
for i_epoch in range(start_epoch, args.max_epochs):
    train_losses = []
    model.train()
    # model.train()
    optimizer.zero_grad()

    for batch in tqdm.tqdm(train_dataloader, total=len(train_dataloader)):
        loss, out, target = model_forward(model, args, criterion, batch, device)
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        train_losses.append(loss.item())
        loss.backward()
        global_step += 1
        if global_step % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    model.eval()
    metrics, classACC, tgts, preds = model_eval(val_dataloader, model, args, criterion, device)

    tuning_metric = (
        metrics["micro_f1"]
    )
    scheduler.step(tuning_metric)
    is_improvement = tuning_metric > best_metric
    if is_improvement:
        best_metric = tuning_metric
        n_no_improve = 0
        torch.save(model.state_dict(), os.path.join(args.savedir, "best_model.pth"))
    else:
        n_no_improve += 1

    if n_no_improve >= args.patience:
        break

100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.6583281919793881
micro_f1: 0.2402745995423341
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.6561637158796871
micro_f1: 0.307277628032345
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.7299933785949649
micro_f1: 0.42534504391468003
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.7402004160531162
micro_f1: 0.4136532612369044
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.7503817502018686
micro_f1: 0.39297771775827145
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.7754386151405697
micro_f1: 0.43079096045197746
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.7882838751554967
micro_f1: 0.42333216905344045
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.799168126301827
micro_f1: 0.43415859346968066
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8104172066507044
micro_f1: 0.45017421602787455
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8198134064309432
micro_f1: 0.49680242342645575
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8293224682385576
micro_f1: 0.5317104420243434
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.832401965543281
micro_f1: 0.5518321327904792
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8404779951850517
micro_f1: 0.5664421310471525
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8500954344361729
micro_f1: 0.5724884080370942
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8563651318289105
micro_f1: 0.5943573667711599
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8685125822889997
micro_f1: 0.6135729779981408
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8749436416531805
micro_f1: 0.6182164392256427
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8726919653115777
micro_f1: 0.6318803690741329
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8721184368634123
micro_f1: 0.6393390530664125
-----------------------------------------------------


100%|██████████| 344/344 [03:25<00:00,  1.67it/s]


micro_auc: 0.8721184368634123
micro_f1: 0.6393390530664125
-----------------------------------------------------


In [76]:
model_state_dict = torch.load("/kaggle/working/best_model.pth")
model_config = AutoConfig.from_pretrained("/kaggle/input/medvill-weight")  
model = MultimodalBertClf.from_pretrained("/kaggle/input/medvill-weight", state_dict=model_state_dict,
                     args=args, configs=model_config, local_files_only=True).to(device)

<__main__.Args object at 0x7eb246563fd0>


In [77]:
torch.cuda.empty_cache()

In [88]:
data_iter = iter(val_dataloader)
data_batch = next(data_iter)
text_tensor, segment_tensor, mask_tensor, img_tensor, tgt_tensor =data_batch[0][2].unsqueeze(0).to(device), data_batch[1][2].unsqueeze(0).to(device), data_batch[2][2].unsqueeze(0).to(device), data_batch[3][2].unsqueeze(0).to(device), data_batch[4][2].unsqueeze(0).to(device)


In [105]:
output = model(text_tensor, mask_tensor, segment_tensor, img_tensor)
pred_bool = torch.sigmoid(output).cpu().detach().numpy() > 0.5
pred_bool

In [106]:
tgt_tensor, args.labels

In [127]:
str_disease = ",".join(diseases)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("ruslanmv/Medical-Llama3-8B")
model = AutoModelForCausalLM.from_pretrained("ruslanmv/Medical-Llama3-8B").to("cuda")  # If using GPU
# Function to format and generate response with prompt engineering using a chat template
def askme(question):
    sys_message = ''' 
    You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
    provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
    '''
    
    # Create messages structured for the chat template
    messages = [{"role": "system", "content": sys_message}, {"role": "user", "content": question}]

    # Applying chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)  # Adjust max_new_tokens for longer responses

    # Extract and return the generated text
    answer = tokenizer.batch_decode(outputs)[0].strip()
    return answer

# Example usage
# - Context: First describe your problem.
# - Question: Then make the question.
question = f"I am suffering from the following diseases: {diseases}, please suggest me habits I need to change as well as methods to treat these diseases."
print(askme(question))

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/449 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/729 [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

pytorch_model-00001-of-00004.bin:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

pytorch_model-00002-of-00004.bin:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

pytorch_model-00003-of-00004.bin:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

pytorch_model-00004-of-00004.bin:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [None]:
# Function to format and generate response with prompt engineering using a chat template
def askme(question):
    sys_message = ''' 
    You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
    provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
    '''
    
    # Create messages structured for the chat template
    messages = [{"role": "system", "content": sys_message}, {"role": "user", "content": question}]

    # Applying chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)  # Adjust max_new_tokens for longer responses

    # Extract and return the generated text
    answer = tokenizer.batch_decode(outputs)[0].strip()
    return answer

# Example usage
# - Context: First describe your problem.
# - Question: Then make the question.
question = f"I am suffering from the following diseases: {diseases}, please suggest me habits I need to change as well as methods to treat this disease."
print(askme(question))