In [None]:
!pip install transformers torchvision


In [None]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, ViTFeatureExtractor, BertModel, ViTModel
from PIL import Image
import matplotlib.pyplot as plt
import os

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
class CrossModalModel(nn.Module):
    """Image + text classifier that fuses a ViT encoder with a BERT encoder.

    For each encoder the embedding at token position 0 of the last hidden
    state is taken; the two vectors are concatenated, passed through one
    hidden layer (ReLU followed by dropout) and mapped to class logits.

    Args:
        num_classes: Number of output classes (size of the logit vector).
        vit_name: Checkpoint name passed to ``ViTModel.from_pretrained``.
        bert_name: Checkpoint name passed to ``BertModel.from_pretrained``.
        fusion_dim: Width of the hidden fusion layer.
        dropout: Dropout probability applied after the fusion layer.

    The defaults reproduce the previously hard-coded architecture, and the
    sub-module attribute names (``vit``, ``bert``, ``fusion``, ``classifier``,
    ``dropout``) are unchanged, so state dicts saved earlier keep their keys.
    """

    def __init__(
        self,
        num_classes=4,
        vit_name="google/vit-base-patch16-224-in21k",
        bert_name="emilyalsentzer/Bio_ClinicalBERT",
        fusion_dim=512,
        dropout=0.3,
    ):
        super().__init__()
        self.vit = ViTModel.from_pretrained(vit_name)
        self.bert = BertModel.from_pretrained(bert_name)

        # Size the fusion layer from the encoders' own configs so that a
        # different checkpoint with another hidden size still fits.
        vit_hidden = self.vit.config.hidden_size
        bert_hidden = self.bert.config.hidden_size
        self.fusion = nn.Linear(vit_hidden + bert_hidden, fusion_dim)
        self.classifier = nn.Linear(fusion_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, image_input, text_input):
        """Compute class logits for a batch of image/text pairs.

        Args:
            image_input: Tensor handed to the ViT encoder as ``pixel_values``.
            text_input: Mapping of keyword arguments unpacked into the BERT
                encoder call (e.g. the dict produced by a tokenizer).

        Returns:
            Tensor of unnormalised logits with ``num_classes`` entries per
            sample.
        """
        # Keep only the embedding at token position 0 of each sequence.
        image_embedding = self.vit(pixel_values=image_input).last_hidden_state[:, 0, :]
        text_embedding = self.bert(**text_input).last_hidden_state[:, 0, :]

        # Concatenate along the feature axis, then project and classify.
        combined = torch.cat((image_embedding, text_embedding), dim=1)
        fused = self.dropout(torch.relu(self.fusion(combined)))
        return self.classifier(fused)


In [None]:
# Rebuild the architecture, then restore the trained weights from disk.
model = CrossModalModel(num_classes=4)
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs, so loading the checkpoint cannot run
# arbitrary pickled code.
model.load_state_dict(
    torch.load("crossmodal_model.pt", map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # inference mode: turns dropout off

# Text and image preprocessors, loaded by checkpoint name.
tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Class index -> label name.
# NOTE(review): presumably this order matches the one used in training — confirm.
label_map = {0: 'glioma_tumor', 1: 'meningioma_tumor', 2: 'no_tumor', 3: 'pituitary_tumor'}


In [None]:
def predict(image_path, clinical_text, actual_label=None):
    """Classify one image/text pair, print a short report and show the image.

    Relies on the notebook-level ``model``, ``tokenizer``,
    ``feature_extractor``, ``label_map`` and ``device`` objects.

    Args:
        image_path: Path of the image file to classify.
        clinical_text: Text passed to the tokenizer; it is truncated or
            padded to 128 tokens.
        actual_label: Optional ground-truth label name. When given (truthy),
            it is printed together with whether the prediction matches it.

    Returns:
        None. The result is printed and the image is displayed with
        matplotlib.
    """
    # Open inside a context manager so the file handle is released right
    # away; convert() returns an independent RGB copy of the pixel data.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = feature_extractor(images=image, return_tensors="pt")['pixel_values'].to(device)

    text_inputs = tokenizer(clinical_text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

    with torch.no_grad():
        logits = model(image_tensor, text_inputs)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        # Probability of the predicted class for the single sample in the batch.
        confidence = probs[0, pred].item()

    predicted_label = label_map[pred]
    image_name = os.path.basename(image_path)

    print(f"Image: {image_name}")
    if actual_label:
        print(f"Actual Label     : {actual_label}")
    print(f"Predicted Label  : {predicted_label}")
    print(f"Confidence Score : {confidence * 100:.2f}%")
    if actual_label:
        print(f"Correct Prediction: {predicted_label == actual_label}")

    # Display image
    plt.imshow(image)
    plt.axis('off')
    plt.title(f"Prediction: {predicted_label}")
    plt.show()


In [None]:
# Try the classifier on three training images, each paired with a short note.
predict(
    image_path="Dataset/Training/glioma_tumor/gg (1).jpg",
    clinical_text="Large mass in frontal lobe with irregular margins.",
)
predict(
    image_path="Dataset/Training/no_tumor/image (3).jpg",
    clinical_text="No abnormal signal intensity detected.",
)
predict(
    image_path="Dataset/Training/pituitary_tumor/p (25).jpg",
    clinical_text="Enlarged pituitary gland observed on MRI.",
)
