In [None]:
!pip install pycocotools torchvision nltk tqdm


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.6.0->torchvision)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86

In [None]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import Dataset

# Class names indexed by the integer label returned by the wrapped dataset.
CIFAR10_LABELS = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


class CIFAR10WithText(Dataset):
    """CIFAR-10 wrapper that yields ``(image, class_name)`` pairs.

    The integer label produced by the underlying dataset is replaced by the
    matching string from ``CIFAR10_LABELS``.

    Args:
        train: Forwarded to ``CIFAR10``; selects the training or test split.
    """

    def __init__(self, train=True):
        # Images are converted with ToTensor only; no other preprocessing is applied.
        to_tensor = transforms.Compose([transforms.ToTensor()])
        self.dataset = CIFAR10(
            root='./data',
            train=train,
            download=True,
            transform=to_tensor,
        )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the image at ``idx`` together with its class name."""
        sample = self.dataset[idx]
        picture, class_index = sample
        return picture, CIFAR10_LABELS[class_index]


In [None]:
# 큰 모델
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer, BertModel
import torchvision.models as models

class InvertibleBlock(nn.Module):
    """Linear -> LayerNorm -> GELU block with a reverse pass for the linear step.

    The forward direction computes ``act(norm(linear(x)))``. The reverse
    direction undoes only the affine ``linear`` map; LayerNorm and GELU are
    not inverted, so ``reverse=True`` is not an exact inverse of the whole
    forward pass.

    Args:
        dim: Feature size. The linear layer maps ``dim -> dim`` so its weight
            is square and can be solved against.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.act = nn.GELU()

    def forward(self, x, reverse=False):
        """Apply the block, or undo its linear layer when ``reverse`` is set.

        Args:
            x: 2-D tensor of shape ``(batch, dim)``.
            reverse: If False, run linear -> norm -> act. If True, return the
                input that ``self.linear`` would map to ``x``.

        Returns:
            Tensor of shape ``(batch, dim)``.
        """
        if not reverse:
            x = self.linear(x)
            x = self.norm(x)
            return self.act(x)

        # nn.Linear computes y = x @ W.T + b, i.e. (y - b).T = W @ x.T.
        # The system therefore has to be solved against W itself; solving
        # against W.T would invert x @ W and only agree for symmetric W.
        W = self.linear.weight
        b = self.linear.bias
        return torch.linalg.solve(W, (x - b).T).T

class SharedNetwork(nn.Module):
    """A stack of ``depth`` InvertibleBlock layers of width ``dim``.

    In the forward direction the blocks run first to last; with
    ``reverse=True`` they run last to first, each in its own reverse mode.
    """

    def __init__(self, dim=768, depth=4):
        super().__init__()
        layers = [InvertibleBlock(dim) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, reverse=False):
        """Pass ``x`` through every block, in the order set by ``reverse``."""
        if reverse:
            sequence = reversed(self.blocks)
        else:
            sequence = self.blocks

        out = x
        for layer in sequence:
            out = layer(out, reverse=reverse)
        return out

class ConvDecoder(nn.Module):
    """Decode a feature vector into a 3-channel image with transposed convolutions.

    A linear layer expands the input to a 512x7x7 feature map. Five stride-2
    transposed convolutions then double the spatial size at every stage
    (7 -> 14 -> 28 -> 56 -> 112 -> 224), and a final sigmoid bounds the output.

    Args:
        dim: Size of the input feature vector.
    """

    def __init__(self, dim=768):
        super().__init__()
        widths = (512, 256, 128, 64, 32, 3)
        base = 7  # spatial side of the map produced by the linear layer

        stages = [
            nn.Linear(dim, widths[0] * base * base),
            nn.ReLU(),
            nn.Unflatten(1, (widths[0], base, base)),
        ]
        last_stage = len(widths) - 2
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # kernel 4 / stride 2 / padding 1 doubles the height and width
            stages.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            # hidden stages use ReLU; the output stage uses a sigmoid
            stages.append(nn.Sigmoid() if stage == last_stage else nn.ReLU())

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map ``(batch, dim)`` features to a ``(batch, 3, 224, 224)`` tensor."""
        return self.net(x)



class BiModalModel(nn.Module):
    """Image <-> text model built around one shared network.

    Images are encoded by a ResNet-18 whose final layer is replaced with a
    512 -> 768 projection; text is encoded by BERT and decoded by GPT-2.
    The shared network is run forward for image -> text and in reverse for
    text -> image. ``txt_proj`` is created but not used by either method
    of this class.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(weights='DEFAULT')
        backbone.fc = nn.Linear(512, 768)
        self.image_encoder = backbone

        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
        self.text_decoder = GPT2LMHeadModel.from_pretrained("gpt2")

        self.shared = SharedNetwork(dim=768)
        self.img_proj = nn.Linear(768, 768)
        self.txt_proj = nn.Linear(768, 768)
        self.image_decoder = ConvDecoder()

    def forward_image_to_text(self, image):
        """Encode ``image`` and feed it to the text decoder as one input embedding.

        Returns:
            The text decoder's output object for a length-1 input sequence.
        """
        features = self.image_encoder(image)  # (B, 768)
        latent = self.shared(features, reverse=False)
        # unsqueeze(1) adds the sequence axis: the decoder sees a single position
        prefix = self.img_proj(latent).unsqueeze(1)
        return self.text_decoder(inputs_embeds=prefix)

    def forward_text_to_image(self, input_ids, attention_mask):
        """Encode the text and decode an image from the reversed shared network."""
        encoded = self.text_encoder(input_ids, attention_mask)
        first_token = encoded.last_hidden_state[:, 0]  # state at position 0 (CLS)
        latent = self.shared(first_token, reverse=True)
        return self.image_decoder(latent)


In [None]:
# 개선된 작은 버전
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import DistilBertModel, DistilBertTokenizer, GPT2LMHeadModel, GPT2Tokenizer

# InvertibleBlock and SharedNetwork (wider shared space)
class InvertibleBlock(nn.Module):
    """Linear -> LayerNorm -> GELU block with a reverse pass for the linear step.

    The forward direction computes ``act(norm(linear(x)))``. The reverse
    direction undoes only the affine ``linear`` map; LayerNorm and GELU are
    not inverted, so ``reverse=True`` is not an exact inverse of the whole
    forward pass.

    Args:
        dim: Feature size. The linear layer maps ``dim -> dim`` so its weight
            is square and can be solved against.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.act = nn.GELU()

    def forward(self, x, reverse=False):
        """Apply the block, or undo its linear layer when ``reverse`` is set.

        Args:
            x: 2-D tensor of shape ``(batch, dim)``.
            reverse: If False, run linear -> norm -> act. If True, return the
                input that ``self.linear`` would map to ``x``.

        Returns:
            Tensor of shape ``(batch, dim)``.
        """
        if not reverse:
            return self.act(self.norm(self.linear(x)))

        # nn.Linear computes y = x @ W.T + b, i.e. (y - b).T = W @ x.T.
        # The system therefore has to be solved against W itself; solving
        # against W.T would invert x @ W and only agree for symmetric W.
        W = self.linear.weight
        b = self.linear.bias
        return torch.linalg.solve(W, (x - b).T).T

class SharedNetwork(nn.Module):
    """A stack of ``depth`` InvertibleBlock layers of width ``dim``.

    In the forward direction the blocks run first to last; with
    ``reverse=True`` they run last to first, each in its own reverse mode.
    """

    def __init__(self, dim=2048, depth=6):
        super().__init__()
        layers = [InvertibleBlock(dim) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, reverse=False):
        """Pass ``x`` through every block, in the order set by ``reverse``."""
        if reverse:
            sequence = reversed(self.blocks)
        else:
            sequence = self.blocks

        out = x
        for layer in sequence:
            out = layer(out, reverse=reverse)
        return out

# Simple image encoder (small)
class SimpleImageEncoder(nn.Module):
    """Two stride-2 convolutions, global average pooling and a linear head.

    Args:
        output_dim: Size of the feature vector produced for each image.
    """

    def __init__(self, output_dim=512):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 32, 3, 2, 1),   # halves H and W (32x16x16 for a 32x32 input)
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),  # halves again (64x8x8 for a 32x32 input)
            nn.ReLU(),
        ]
        head = [
            nn.AdaptiveAvgPool2d((1, 1)),  # collapse the spatial dims to 1x1
            nn.Flatten(),
            nn.Linear(64, output_dim),
        ]
        self.net = nn.Sequential(*conv_stack, *head)

    def forward(self, x):
        """Return a ``(batch, output_dim)`` feature tensor for image batch ``x``."""
        return self.net(x)

# Simple ConvDecoder
class ConvDecoder(nn.Module):
    """Decode a feature vector into a 3-channel 32x32 image.

    A linear layer expands the input to a 256x4x4 feature map. Three stride-2
    transposed convolutions then double the spatial size each time
    (4 -> 8 -> 16 -> 32), and a final sigmoid bounds the output.

    Args:
        dim: Size of the input feature vector.
    """

    def __init__(self, dim=2048):
        super().__init__()
        widths = (256, 128, 64, 3)
        base = 4  # spatial side of the map produced by the linear layer

        stages = [
            nn.Linear(dim, widths[0] * base * base),
            nn.ReLU(),
            nn.Unflatten(1, (widths[0], base, base)),
        ]
        last_stage = len(widths) - 2
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # kernel 4 / stride 2 / padding 1 doubles the height and width
            stages.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            # hidden stages use ReLU; the output stage uses a sigmoid
            stages.append(nn.Sigmoid() if stage == last_stage else nn.ReLU())

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map ``(batch, dim)`` features to a ``(batch, 3, 32, 32)`` tensor."""
        return self.net(x)

class BiModalModel(nn.Module):
    """Image <-> text model that meets in a 2048-dimensional shared network.

    Images are encoded by a ResNet-18 whose final layer is replaced with a
    512 -> 512 projection; text is encoded by DistilBERT and decoded by
    DistilGPT-2. Both modalities are projected to 2048 features before the
    shared network, which runs forward for image -> text and in reverse for
    text -> image.
    """

    def __init__(self):
        super().__init__()
        # SimpleImageEncoder(output_dim=512) is the lighter alternative encoder.
        backbone = models.resnet18(weights='DEFAULT')
        backbone.fc = nn.Linear(512, 512)
        self.image_encoder = backbone
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.text_decoder = GPT2LMHeadModel.from_pretrained("distilgpt2")

        # Per-modality projections into the shared space (dim=2048)
        self.to_shared_img = nn.Linear(512, 2048)
        self.to_shared_txt = nn.Linear(768, 2048)

        self.shared = SharedNetwork(dim=2048, depth=6)

        # Project shared features to the text decoder's embedding width
        gpt_width = self.text_decoder.config.n_embd
        self.to_gpt_embed = nn.Linear(2048, gpt_width)
        self.image_decoder = ConvDecoder(dim=2048)

    def forward_image_to_text(self, image):
        """Encode ``image`` and feed it to the text decoder as one input embedding.

        Returns:
            The text decoder's output object for a length-1 input sequence.
        """
        features = self.image_encoder(image)  # (B, 512)
        projected = self.to_shared_img(features)
        latent = self.shared(projected, reverse=False)
        # unsqueeze(1) adds the sequence axis: the decoder sees a single position
        prefix = self.to_gpt_embed(latent).unsqueeze(1)
        return self.text_decoder(inputs_embeds=prefix)

    def forward_text_to_image(self, input_ids, attention_mask):
        """Encode the text and decode an image from the reversed shared network."""
        encoded = self.text_encoder(input_ids, attention_mask)
        first_token = encoded.last_hidden_state[:, 0]  # state at position 0
        projected = self.to_shared_txt(first_token)
        latent = self.shared(projected, reverse=True)
        return self.image_decoder(latent)



In [None]:

import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, GPT2Tokenizer

# Joint training script: image -> text (token cross-entropy) plus
# text -> image (pixel MSE), with per-epoch checkpointing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizers
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token  # reuse EOS as the padding token

# Initialise the model and optimizer
model = BiModalModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Dataset and loader
dataset = CIFAR10WithText(train=True)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

os.makedirs("checkpoints", exist_ok=True)  # directory the checkpoints are written to

best_loss = float("inf")
best_acc = 0.0

model.train()
for epoch in range(5):
    total_correct = 0
    total_tokens = 0
    total_loss = 0.0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")
    for images, texts in pbar:
        images = images.to(device)

        # 1) Image -> text
        gpt_enc = gpt_tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        input_ids = gpt_enc.input_ids.to(device)

        out = model.forward_image_to_text(images)
        logits = out.logits
        # Logits and targets can differ in sequence length; score only the
        # positions that both of them have.
        min_len = min(logits.size(1), input_ids.size(1))
        step_logits = logits[:, :min_len, :]
        targets = input_ids[:, :min_len]

        # Token-level loss, ignoring padded positions
        loss1 = F.cross_entropy(
            step_logits.reshape(-1, step_logits.size(-1)),
            targets.reshape(-1),
            ignore_index=gpt_tokenizer.pad_token_id
        )

        # Token accuracy over the non-padding positions
        preds = torch.argmax(step_logits, dim=-1)
        mask = (targets != gpt_tokenizer.pad_token_id)

        correct = (preds == targets) & mask
        total_correct += correct.sum().item()
        total_tokens += mask.sum().item()

        # 2) Text -> image
        bert_enc = bert_tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
        recon_img = model.forward_text_to_image(bert_enc.input_ids, bert_enc.attention_mask)

        loss2 = F.mse_loss(recon_img, images)

        # Combined objective
        loss = loss1 + loss2
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Refresh the progress-bar readout with the running accuracy
        acc = total_correct / total_tokens * 100 if total_tokens > 0 else 0
        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'loss1': f"{loss1.item():.4f}",
            'loss2': f"{loss2.item():.4f}",
            'acc': f"{acc:.2f}%"
        })

    # Save checkpoints at the end of every epoch
    avg_loss = total_loss / len(loader)
    # Same zero-token guard as the running accuracy above
    final_acc = total_correct / total_tokens * 100 if total_tokens > 0 else 0.0

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), "checkpoints/best_loss.pt")

    if final_acc > best_acc:
        best_acc = final_acc
        torch.save(model.state_dict(), "checkpoints/best_acc.pt")

    torch.save(model.state_dict(), "checkpoints/last.pt")
    print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f} | Text Accuracy: {final_acc:.2f}%")


Epoch 1:  38%|███▊      | 1189/3125 [1:29:50<2:51:34,  5.32s/it, loss=1.6546, loss1=1.3550, loss2=0.2996, acc=41.93%]