In [None]:
import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoModel, AutoTokenizer

class MultiModalAneurysmClassifier(nn.Module):
    """Multi-view image + text classifier with independent sigmoid outputs.

    Each of the ``num_views`` views is encoded by a torchvision backbone whose
    final child module (the FC layer) is removed. Its text description is
    encoded by a Hugging Face model, using the first-token (CLS) hidden state.
    Per-view image and text features are concatenated and passed through a
    shared fusion MLP. The fused features of all views are then concatenated
    and mapped to ``num_classes`` sigmoid probabilities.

    Args:
        text_model_name: Hugging Face checkpoint for the text encoder and tokenizer.
        image_model_name: name of a ``torchvision.models`` constructor. The model
            must expose ``.fc.in_features`` (e.g. the ResNet family).
        hidden_dim: output width of the fusion MLP.
        num_views: number of views per sample (default 8).
        num_classes: number of sigmoid outputs (default 22).
    """

    def __init__(self, text_model_name="bert-base-uncased", image_model_name="resnet18", hidden_dim=512,
                 num_views=8, num_classes=22):
        super().__init__()
        self.num_views = num_views
        self.num_classes = num_classes

        # Image encoder: pretrained backbone with the final FC layer removed.
        image_factory = models.__dict__[image_model_name]
        try:
            image_model = image_factory(weights="DEFAULT")
        except TypeError:
            # torchvision < 0.13 has no `weights` argument; use the legacy flag.
            image_model = image_factory(pretrained=True)
        self.image_encoder = nn.Sequential(*list(image_model.children())[:-1])
        self.image_feature_dim = image_model.fc.in_features

        # Text encoder
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)
        self.text_feature_dim = self.text_encoder.config.hidden_size

        # Fusion MLP, shared by all views.
        self.fusion_layer = nn.Sequential(
            nn.Linear(self.image_feature_dim + self.text_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Final classifier over the concatenation of all fused views.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * num_views, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
            nn.Sigmoid()
        )

    def _texts_for_view(self, texts, view, batch_size):
        """Return the ``batch_size`` text strings that belong to view ``view``.

        Two layouts are accepted:

        * view-major, ``texts[view][b]``: what ``DataLoader``'s default collate
          produces when each sample returns a list of ``num_views`` strings
          (``num_views`` sequences, each holding ``batch_size`` strings);
        * sample-major, ``texts[b][view]``: ``batch_size`` per-sample sequences.

        If both layouts fit (``batch_size == num_views``), the view-major
        DataLoader layout is assumed.

        Raises:
            ValueError: if ``texts`` matches neither layout.
        """
        def _fits(outer, inner):
            return len(texts) == outer and all(
                not isinstance(entry, str) and len(entry) == inner for entry in texts
            )

        if _fits(self.num_views, batch_size):
            return list(texts[view])
        if _fits(batch_size, self.num_views):
            return [sample[view] for sample in texts]
        raise ValueError(
            f"texts must be {self.num_views} sequences of {batch_size} strings (view-major) "
            f"or {batch_size} sequences of {self.num_views} strings (sample-major)"
        )

    def forward(self, images, texts):
        """
        images: [B, num_views, 3, H, W] tensor.
        texts: per-view text descriptions, either view-major (texts[v][b], the
            layout produced by DataLoader's default collate) or sample-major
            (texts[b][v]).
        Returns: [B, num_classes] sigmoid probabilities.
        """
        if images.dim() != 5 or images.size(1) != self.num_views:
            raise ValueError(
                f"expected images of shape [B, {self.num_views}, C, H, W], got {tuple(images.shape)}"
            )
        batch_size = images.size(0)
        fused_views = []

        for view in range(self.num_views):
            # Backbone output is [B, C, 1, 1] after global pooling; flatten to [B, C].
            img_feat = torch.flatten(self.image_encoder(images[:, view]), 1)

            txt_batch = self._texts_for_view(texts, view, batch_size)  # List[B]
            tokens = self.tokenizer(txt_batch, return_tensors="pt", padding=True, truncation=True).to(images.device)
            txt_feat = self.text_encoder(**tokens).last_hidden_state[:, 0, :]  # CLS token

            fused = self.fusion_layer(torch.cat([img_feat, txt_feat], dim=1))  # [B, hidden_dim]
            fused_views.append(fused)

        combined = torch.cat(fused_views, dim=1)  # [B, hidden_dim * num_views]
        return self.classifier(combined)  # [B, num_classes]


In [None]:
def make_text_descriptions(index):
    """Return the eight per-view text descriptions in LI-A, LI-B, ..., RV-B order.

    ``index`` is accepted for call-site compatibility but is ignored: every
    sample gets the same descriptions. A fresh list is returned on each call,
    so callers may mutate it freely.
    """
    descriptions = (
        "Left internal carotid artery injection, view A",
        "Left internal carotid artery injection, view B",
        "Right internal carotid artery injection, view A",
        "Right internal carotid artery injection, view B",
        "Left vertebral artery injection, view A",
        "Left vertebral artery injection, view B",
        "Right vertebral artery injection, view A",
        "Right vertebral artery injection, view B",
    )
    return list(descriptions)


In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
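# Example site-specific absolute paths (the cells below use the relative paths "train.csv" / "images"):
# csv_path = "/home/edlab/sjim/k-ium-coding-vessels/train_set/train.csv"
# image_dir = "/home/edlab/sjim/k-ium-coding-vessels/train_set/images"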
class AneurysmDataset(Dataset):
    """One sample per CSV row: 8 view images, their text descriptions, and labels.

    The ``Index`` column is the patient id used to build image paths
    ``{image_dir}/{Index}_{view}.jpg``. Every column after the first is
    treated as a label (positional).

    Args:
        csv_path: path to the CSV file.
        image_dir: directory containing the view images.
        tokenizer: stored on the instance; this class does not use it (texts are
            returned as raw strings).
        transform: callable applied to each RGB ``PIL.Image``. It must return
            tensors of identical shape for all views so they can be stacked.
    """

    def __init__(self, csv_path, image_dir, tokenizer, transform=None):
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer

        self.image_order = ['LI-A', 'LI-B', 'RI-A', 'RI-B', 'LV-A', 'LV-B', 'RV-A', 'RV-B']
        self.text_map = {
            'LI-A': "Left internal carotid artery injection, view A",
            'LI-B': "Left internal carotid artery injection, view B",
            'RI-A': "Right internal carotid artery injection, view A",
            'RI-B': "Right internal carotid artery injection, view B",
            'LV-A': "Left vertebral artery injection, view A",
            'LV-B': "Left vertebral artery injection, view B",
            'RV-A': "Right vertebral artery injection, view A",
            'RV-B': "Right vertebral artery injection, view B",
        }

        # Labels = every column after the first, converted to float32 once.
        # Selecting only those columns keeps the conversion numeric even when the
        # Index column holds strings. In that case `row.values` would be an
        # object array that torch.tensor cannot convert.
        self.labels = self.df.iloc[:, 1:].to_numpy(dtype="float32")

    def __len__(self):
        return len(self.df)

    def _load_view(self, patient_id, suffix):
        """Open one view image, convert it to RGB and apply ``transform``."""
        image_path = os.path.join(self.image_dir, f"{patient_id}_{suffix}.jpg")
        # Close the file handle promptly; convert() returns a new image object.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        patient_id = str(row["Index"])

        # 1. Load images in the fixed view order -> [8, 3, H, W].
        images = torch.stack([self._load_view(patient_id, suffix) for suffix in self.image_order])

        # 2. Text description for each view, in the same order as the images.
        texts = [self.text_map[suffix] for suffix in self.image_order]

        # 3. Labels (one float per label column).
        label = torch.tensor(self.labels[idx], dtype=torch.float)

        return images, texts, label


In [None]:
from transformers import AutoTokenizer

# Text tokenizer for the "bert-base-uncased" checkpoint.
# NOTE(review): the training cell later in the visible notebook builds its own
# tokenizer/dataset/loaders, so `dataset` and `dataloader` here appear unused there.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Image preprocessing: resize to 224x224, convert to a tensor, then normalise
# each channel with the ImageNet mean/std.
imagenet_stats = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(**imagenet_stats),
    ]
)

# Dataset & DataLoader over the whole CSV (no train/val split in this cell).
dataset = AneurysmDataset("train.csv", "images", tokenizer, transform=image_transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)


In [None]:
# train
import os
import wandb
import torch
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer
from torchvision import transforms
from tqdm import tqdm

# from model import MultiModalAneurysmClassifier  # model class defined in an earlier cell
# from dataset import AneurysmDataset             # Dataset class defined in an earlier cell

# -------------------- Settings -------------------- #
CSV_PATH = "train.csv"
IMAGE_DIR = "images"
TEXT_MODEL_NAME = "bert-base-uncased"
IMAGE_MODEL_NAME = "resnet18"
EPOCHS = 3
BATCH_SIZE = 4
LR = 1e-4
VAL_INTERVAL = 1  # run a full validation pass every VAL_INTERVAL training steps
SEED = 42  # fixes the train/val split and the model's random initialisation
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED)

wandb.init(project="aneurysm-multimodal", config={
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "lr": LR,
    "image_model": IMAGE_MODEL_NAME,
    "text_model": TEXT_MODEL_NAME,
    "seed": SEED,
})

# -------------------- Load Data -------------------- #
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

full_dataset = AneurysmDataset(CSV_PATH, IMAGE_DIR, tokenizer, transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# Seeded generator so the 80/20 split is identical across runs.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# -------------------- Init Model -------------------- #
model = MultiModalAneurysmClassifier(TEXT_MODEL_NAME, IMAGE_MODEL_NAME).to(DEVICE)
criterion = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)

# -------------------- Validation Loop -------------------- #
@torch.no_grad()
def evaluate():
    """Run one pass over ``val_loader``.

    Returns:
        (mean per-batch loss, element-wise accuracy at threshold 0.5). Both are
        NaN when the validation split is empty.

    The model's previous train/eval mode is restored on exit. This function is
    called in the middle of training, and leaving the model in eval mode would
    disable dropout and freeze BatchNorm for all later training steps.
    """
    was_training = model.training
    model.eval()
    total_loss, total_correct, total, num_batches = 0.0, 0.0, 0, 0
    try:
        for images, texts, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images, texts)
            total_loss += criterion(outputs, labels).item()
            num_batches += 1

            preds = (outputs > 0.5).float()
            total_correct += (preds == labels).float().sum().item()
            total += labels.numel()
    finally:
        model.train(was_training)

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / num_batches, total_correct / total

# -------------------- Training Loop -------------------- #
model.train()
global_step = 0
for epoch in range(EPOCHS):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for images, texts, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images, texts)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # logging
        loss_value = loss.item()
        wandb.log({"train/loss": loss_value, "step": global_step})
        pbar.set_postfix(loss=loss_value)

        # validation (evaluate() puts the model back into train mode)
        if global_step % VAL_INTERVAL == 0:
            val_loss, val_acc = evaluate()
            wandb.log({
                "val/loss": val_loss,
                "val/accuracy": val_acc,
                "step": global_step
            })

        global_step += 1

# Save model
torch.save(model.state_dict(), "aneurysm_model.pth")
wandb.save("aneurysm_model.pth")
