In [1]:
import os

folder_path = "CLIP"  # Replace with the actual path to your IMG folder

# Collect the path of every regular file directly inside the folder.
# Sub-directories are skipped; entries keep the order returned by os.listdir.
files = []
for entry in os.listdir(folder_path):
    candidate = os.path.join(folder_path, entry)
    if os.path.isfile(candidate):
        files.append(candidate)

# Write one path per line to the output text file
output_file_path = "ouput_drug.txt"  # Replace with the desired output file path
with open(output_file_path, 'w') as output_file:
    output_file.writelines(path + '\n' for path in files)

print(f"File names and paths exported to {output_file_path}")

File names and paths exported to ouput_drug.txt


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from transformers import BertTokenizer, BertModel
from torchvision.models import resnet18
import torch.nn as nn

# Start with an empty list of labels
text_list = list()
import re

def remove_parentheses(file_name):
    """Return ``file_name`` with every parenthesized group removed.

    Each ``(...)`` span -- the parentheses and the text between them -- is
    deleted. All other characters, including whitespace next to a removed
    group, are left untouched.
    """
    parenthesized_group = re.compile(r'\([^)]*\)')
    return parenthesized_group.sub('', file_name)

text_list = []
image_paths = []

# Build two parallel lists from the exported path file:
#   text_list[i]   -- label derived from the file name of image_paths[i]
#   image_paths[i] -- the path exactly as stored in the file (whitespace-trimmed)
with open('ouput_drug.txt', 'r') as file:
    for line in file:
        image_path = line.strip()
        if not image_path:
            # Skip blank lines so no empty label/path pair is recorded
            continue

        # Take the last path component. Both "/" and "\" are accepted because
        # the paths were produced with os.path.join, whose separator is
        # platform dependent.
        base_name = re.split(r'[\\/]', image_path)[-1]

        # Keep the text before the first dot, drop "(...)" groups, and trim the
        # whitespace a removed group can leave behind ("name (1)" -> "name")
        file_name = remove_parentheses(base_name.split(".")[0]).strip()

        text_list.append(file_name)
        image_paths.append(image_path)

# Show the resulting text_list and image_paths
print("text_list =", text_list)
print("image_paths =", image_paths)

text_list = ['valproate', 'hydroxyurea', 'benzromarone', 'praziquantel', 'xanax', 'atrovent', 'botox', 'co', 'disulfiram', 'alprazolam', 'carbidopa', 'acetaminophen', 'corticosteroid', 'octreotide', 'voclosporin', 'propecia', 'dextromethorphan', 'acamprosate', 'fluoxetine', 'valganciclovir', 'levofloxacin', 'nalfon', 'phenytoin', 'colchicine', 'norpace', 'trileptal', 'clarithromycin', 'auranofin', 'methadone', 'cevimeline', 'clopidogrel', 'voriconazole', 'phenobarbital', 'sildenafil', 'potassium_gluconate', 'tofacitinib', 'efavirenz', 'exjade', 'cephalexin', 'nimodipine', 'dexmedetomidine', 'brevibloc', 'flucytosine', 'ciprofloxacin', 'deucravacitinib', 'amoxicillin_sulbactam', 'captopril', 'digoxin', 'amantadine', 'digoxin', 'factive', 'buprenorphine', 'prednisolone', 'busulfan', 'cefixime', 'tiagabine', 'propranolol ', 'ticagrelor', 'kcl', 'infliximab', 'ethyol', 'neurontin', 'methylene_blue', 'ibuprofen', 'guafacine', 'abilify', 'triazolam', 'avastin', 'doxycycline', 'lexapro', 'spi

In [3]:
import torch
from PIL import Image
from torchvision import transforms
from transformers import BertTokenizer, BertModel
from torchvision.models import resnet18
import torch.nn as nn

# Define the CLIP-style model
class ClipModel(nn.Module):
    """Pair a frozen BERT text encoder with a ResNet18 image encoder.

    Image features from the truncated ResNet18 are passed through a linear
    layer so they have the same width as the BERT hidden size, which lets
    text and image embeddings be compared directly.

    Args:
        text_encoder_name: Name or path handed to ``BertModel.from_pretrained``
            and ``BertTokenizer.from_pretrained``.
    """

    def __init__(self, text_encoder_name='bert-base-uncased'):
        super().__init__()

        # Load pre-trained BERT and its matching tokenizer
        self.text_encoder = BertModel.from_pretrained(text_encoder_name)
        self.tokenizer = BertTokenizer.from_pretrained(text_encoder_name)

        # Freeze BERT parameters so they receive no gradient updates
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        # Load pre-trained ResNet18
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favor of `weights=`; kept as-is for compatibility with older versions.
        self.image_encoder = resnet18(pretrained=True)
        # Remove the fully connected layer of ResNet18 (keep every child but the last)
        self.image_encoder = nn.Sequential(*list(self.image_encoder.children())[:-1])

        # Additional layer for mapping the 512 ResNet18 features to BERT output dimensions
        self.mapping_layer = nn.Linear(512, self.text_encoder.config.hidden_size)

    def forward(self, text, image):
        """Encode text and image(s) into embeddings of BERT hidden size.

        Args:
            text: A string or list of strings accepted by the tokenizer.
            image: Image tensor batch fed to the ResNet18 encoder
                (presumably ``(batch, 3, H, W)`` -- confirm against caller).

        Returns:
            Tuple ``(text_pooler_output, mapped_outputs)``: the BERT pooler
            output for the text, and the image features projected to
            ``hidden_size`` with shape ``(batch, hidden_size)``.
        """
        # Text encoding with BERT
        text_inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        # The tokenizer returns CPU tensors; move them to wherever the model's
        # weights live so the forward pass also works after `.to('cuda')`
        text_inputs = text_inputs.to(self.mapping_layer.weight.device)
        text_outputs = self.text_encoder(**text_inputs, return_dict=True)
        text_pooler_output = text_outputs.pooler_output  # Using pooler output

        # Image encoding with ResNet18
        image_outputs = self.image_encoder(image)
        image_outputs = image_outputs.view(image_outputs.size(0), -1)  # Flatten the output

        # Mapping ResNet18 output to BERT output dimensions
        mapped_outputs = self.mapping_layer(image_outputs)

        return text_pooler_output, mapped_outputs

def inference(image_path, clip_model, text_list):
    """Return the candidate text whose embedding is closest to the image.

    The image is resized to 224x224, embedded with the model's image encoder
    and mapping layer, and compared by cosine similarity against the BERT
    pooler output of every candidate text.

    Args:
        image_path: Path of the image file to classify.
        clip_model: Model exposing ``tokenizer``, ``text_encoder``,
            ``image_encoder`` and ``mapping_layer``. It is moved to the GPU
            when one is available. The caller is expected to have put it in
            eval mode.
        text_list: Candidate texts. The list is not modified.

    Returns:
        The candidate text with the highest cosine similarity (the first one
        in sorted order if there is a tie).

    Raises:
        ValueError: If ``text_list`` is empty.
    """
    if not text_list:
        raise ValueError("text_list must contain at least one candidate text")

    # Sort a copy so the caller's list is not reordered as a side effect
    candidates = sorted(text_list)

    # Read the image; the context manager closes the file handle once the
    # RGB copy has been made
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # Data transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Move the image (as a batch of one) and the model to the GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_batch = transform(image).unsqueeze(0).to(device)
    clip_model = clip_model.to(device)

    similarities = []
    # Inference only: no gradients are needed for either encoder
    with torch.no_grad():
        # Embed the image directly (same steps as the model's forward pass)
        # instead of also running the text encoder on an empty string
        image_features = clip_model.image_encoder(image_batch)
        image_features = image_features.view(image_features.size(0), -1)
        image_output = clip_model.mapping_layer(image_features)

        # Compute the similarity between the input image and each candidate text
        for text in candidates:
            text_inputs = clip_model.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            # Tokenizer output is created on the CPU; move it next to the model
            text_inputs = text_inputs.to(device)
            text_outputs = clip_model.text_encoder(**text_inputs, return_dict=True)
            text_pooler_output = text_outputs.pooler_output

            # Cosine similarity along the feature dimension
            similarity = torch.nn.functional.cosine_similarity(image_output, text_pooler_output, dim=1)
            similarities.append(similarity.item())

    # Pick the text with the highest similarity
    max_similarity_index = similarities.index(max(similarities))
    most_similar_text = candidates[max_similarity_index]

    print(candidates)
    print(similarities)
    print(most_similar_text)
    return most_similar_text

# Load the trained model weights
trained_model_path = 'Model_CLIP/clip_model_epoch99.pth'  # Replace with the path to your model weights
clip_model = ClipModel()
# NOTE(review): depending on the torch version, torch.load may unpickle
# arbitrary objects -- only load checkpoints from a trusted source.
clip_model.load_state_dict(
    torch.load(trained_model_path, map_location='cpu')
)
clip_model.eval()

# Replace with the path of the image you want to run inference on
inference_image_path = 'CLIP/abatacept.png'
result_text = inference(inference_image_path, clip_model, text_list)

print(f"The most similar text for the given image is: {result_text}")





['abatacept', 'abilify', 'acamprosate', 'acarbose', 'acetaminophen', 'acetazolamide', 'acyclovir', 'albendazole', 'aliskiren', 'allopurinol', 'alprazolam', 'amantadine', 'amlodipine', 'amoxicillin_sulbactam', 'anakinra', 'aprepitant', 'aquamephyton', 'arimidex', 'atropine', 'atrovent', 'auranofin', 'avastin', 'azathioprine', 'benzromarone', 'bisoprolol', 'botox', 'brevibloc', 'buprenorphine', 'busulfan', 'busulfex', 'captopril', 'carbamazepine', 'carbidopa', 'cefixime', 'celecoxib', 'cephalexin', 'cevimeline', 'cholestyramine', 'cilostazol', 'cimetidine', 'cinacalcet', 'ciprofloxacin', 'clarithromycin', 'clonidine', 'clopidogrel', 'clozapine', 'co', 'colchicine', 'colchicine', 'coreg', 'corticosteroid', 'corticosteroids', 'crixivan', 'cuprimine', 'cycloserine', 'danazol', 'dapsone', 'deucravacitinib', 'dexmedetomidine', 'dextromethorphan', 'diclofenac', 'digoxin', 'digoxin', 'diphenhydramine', 'diphenidol', 'diphenidol', 'disulfiram', 'dobutamine', 'doxepin', 'doxycycline', 'efavirenz'

In [None]:
0.6605501413345337
0.38850514888763427
0.23675904870033265
0.1462131842970848
0.09409026354551316
0.06461078375577926
0.04853979051113129
0.03967658579349518
0.034734292700886724
0.03203148581087589
0.029348791763186454
0.02780153863132
0.026392125338315964
0.025132407248020173
0.023967040330171586
0.023512064665555953
0.022044331207871436
0.02057553008198738
0.01916919872164726
0.018795470520853997
0.017295506224036215
0.015917360410094263
0.015243663266301156
0.014299644716084003
0.013275043666362762
0.013265021890401841
0.012687429413199424
0.011689359322190285
0.011487093195319176
0.010823146440088749
0.010088427364826203
0.01003305371850729
0.009321731701493263
0.00908319428563118
0.00878697857260704
0.008741279318928718
0.008370254933834077
0.007961973827332259
0.007733644265681505
0.007220143172889948
0.007148622069507837
0.006704780645668507
0.006691183056682348
0.00665165176615119
0.00843109590932727
0.006463685911148787
0.006265654880553484
0.0063707692548632625
0.00634037172421813
0.005970792658627033
0.006004357989877462
0.005827054940164089
0.006455147825181484
0.005471691396087408
0.005302015878260136
0.0061350538395345214
0.005555207468569278
0.005673058610409498
0.005200700368732214
0.006273916643112898
0.004851007275283337
0.004962639044970274
0.0054004877805709835
0.004904125537723303
0.005557434633374214
0.004786490090191364
0.004774809535592795
0.004815138038247823
0.005124165955930948
0.005143087124451995
0.005733414459973574
0.004859576933085918
0.004574669292196631
0.005126507300883531
0.00426047183573246
0.004532157024368644
0.004463887214660645
0.004262052243575454
0.0039418132044374945
0.004691322287544608
0.004420235008001328
0.003768266411498189
0.003860887186601758
0.003746120957657695
0.004152480652555824
0.0038853527512401342
0.004234118154272437
0.0036459486465901135
0.004245704645290971
0.004662737809121608
0.003711740905418992
0.004162572696805001
0.004161321185529232
0.003774043684825301
0.0034573666285723447
0.003563568741083145
0.004494583280757069
0.0033989906776696445
0.003655323339626193
0.003437324892729521