In [None]:
from transformers import AutoTokenizer, AutoModel


import torch
from torch import nn

In [None]:
# Pretrained ESM-2 checkpoint used as the encoder backbone; tokenizer and model
# must come from the same checkpoint id.
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

In [None]:
# Smoke test: push one amino-acid sequence through the backbone and show the
# shape of the hidden states. The backbone is applied to BCR sequences below,
# so an amino-acid string (not English text) is the meaningful test input.
with torch.no_grad():  # inference only; no need to build an autograd graph
    inputs = tokenizer("EVQLVESGGGLVQPGGSLRLSCAAS", return_tensors="pt")
    outputs = model(**inputs)
outputs.last_hidden_state.shape

In [None]:

import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Load the BCR table and encode the "target" column as integer class ids in a
# new "label" column; each id is the index of its class in le.classes_.
bcr = pd.read_parquet("../data_dir/bcr.parquet")

le = LabelEncoder()
bcr["label"] = le.fit_transform(bcr["target"])
bcr.head()

In [None]:
from sklearn.model_selection import train_test_split

# Stratified split (sklearn's default 75/25) so both splits keep the same
# label proportions; fixed random_state for reproducibility.
bcr_train, bcr_val = train_test_split(bcr, stratify=bcr["label"], random_state=0)

# Reset to a 0..n-1 index so rows can be addressed by position downstream.
# Reassign instead of inplace=True so re-running the cell is side-effect free.
bcr_train = bcr_train.reset_index(drop=True)
bcr_val = bcr_val.reset_index(drop=True)

# Sanity checks. Only a cell's last expression is auto-displayed, so show
# these explicitly: per-split mean label (the class-1 fraction when labels are
# 0/1; should roughly match thanks to stratification) and split sizes.
display((bcr_train["label"].mean(), bcr_val["label"].mean()))
display((bcr_train.shape, bcr_val.shape))
bcr_train.head()

In [None]:
from torch.utils.data import Dataset
class BCRDataset(Dataset):
    """Map-style dataset that tokenizes one BCR sequence per item on the fly.

    Args:
        df: DataFrame with a "sequence" column (text passed to the tokenizer)
            and an integer "label" column. Rows are looked up by position, so
            the index does not have to be a 0..n-1 RangeIndex.
        tokenizer: Hugging Face-style tokenizer. When called with
            return_tensors="pt" it must return "input_ids" and
            "attention_mask" tensors with a leading batch dimension of size 1.
        max_length: length every sequence is padded / truncated to
            (default 320, the value previously hard-coded here).

    Each item is a dict with "input_ids" and "attention_mask" (1-D tensors,
    max_length long assuming the tokenizer honors padding="max_length") and a
    scalar int64 "labels" tensor.
    """

    def __init__(self, df, tokenizer, max_length=320):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        # Positional access: samplers yield 0..len-1, which only matches
        # label-based .loc when the DataFrame index has been reset.
        sequence = self.df["sequence"].iloc[i]
        label = self.df["label"].iloc[i]

        encoded = self.tokenizer(
            sequence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        # Drop the size-1 batch dimension added by return_tensors="pt";
        # collate_fn stacks the items back into a batch.
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        labels = torch.tensor(label, dtype=torch.long)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# One dataset per split; both share the same tokenizer.
train_ds, val_ds = (BCRDataset(split_df, tokenizer) for split_df in (bcr_train, bcr_val))

In [None]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Stack a list of BCRDataset items into one batch dict.

    Every item carries "input_ids", "attention_mask" and "labels" tensors; the
    result has the same keys, each tensor stacked along a new leading dim.
    """
    keys = ("input_ids", "attention_mask", "labels")
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}
BATCH_SIZE = 32

# Loaders consumed by the training loop; both batch items with collate_fn.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# NOTE(review): train_dl is not referenced anywhere in the visible notebook and
# duplicates train_loader (default collation instead of collate_fn);
# candidate for removal.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)


In [None]:
class ESMClassifier(nn.Module):
    """Linear classification head on top of an ESM-style encoder.

    The encoder's hidden state at sequence position 0 is fed to one linear
    layer. (Assumes the tokenizer places a <cls>-style summary token at
    position 0, as the ESM tokenizer does — confirm for other backbones.)

    Args:
        model: backbone called as model(input_ids=..., attention_mask=...)
            whose output exposes `last_hidden_state`, indexed as
            (batch, position, hidden).
        num_labels: number of output classes (default 2, as before).
        hidden_size: width of the backbone's hidden states. When None it is
            read from model.config.hidden_size, falling back to 320 (the
            hidden size of esm2_t6_8M_UR50D) if the backbone has no config.
    """

    def __init__(self, model, num_labels=2, hidden_size=None):
        super().__init__()
        self.model = model
        if hidden_size is None:
            hidden_size = getattr(getattr(model, "config", None), "hidden_size", 320)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape (batch, num_labels)."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first position's embedding as the whole-sequence summary.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return self.classifier(cls_embedding)

In [None]:
SEED = 0
# Seed torch's global RNG before building the model so the fresh classifier
# head's initialization and train_loader's shuffle order are reproducible
# (same seed as random_state=0 used for the train/val split).
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bcr_model = ESMClassifier(model=model).to(device)

# Optimizer and loss function
optimizer = torch.optim.Adam(bcr_model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()



In [None]:
from sklearn.metrics import accuracy_score

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run ``model`` once over every batch in ``loader``.

    With an ``optimizer`` the model is put in train mode and updated after each
    batch; without one it is put in eval mode and run under ``torch.no_grad``.

    Args:
        model: module called as model(input_ids=..., attention_mask=...) that
            returns logits indexed as (batch, class).
        loader: iterable of dicts holding "input_ids", "attention_mask" and
            "labels" tensors (e.g. a DataLoader using collate_fn).
        criterion: loss taking (logits, labels). Assumed to average over the
            batch (reduction="mean", the nn.CrossEntropyLoss default), since the
            batch loss is re-weighted by batch size below.
        device: device every batch tensor is moved to.
        optimizer: optimizer stepped after each batch, or None to evaluate.

    Returns:
        (mean loss per sample, accuracy) over all samples seen; both are NaN
        when ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    grad_context = torch.enable_grad() if training else torch.no_grad()
    with grad_context:
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            if training:
                optimizer.zero_grad()
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(logits, labels)
            if training:
                loss.backward()
                optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the
            # epoch mean (a plain average of per-batch means would).
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_correct += int((logits.argmax(dim=1) == labels).sum().item())
            n_seen += batch_size

    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, n_correct / n_seen


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Returns:
        (mean training loss per sample, training accuracy). Both are computed
        from each batch's logits before that batch's optimizer step.

    Raises:
        ValueError: if ``optimizer`` is None.
    """
    if optimizer is None:
        raise ValueError("train_one_epoch requires an optimizer")
    return _run_epoch(model, loader, criterion, device, optimizer=optimizer)


def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating its weights.

    Returns:
        (mean loss per sample, accuracy).
    """
    return _run_epoch(model, loader, criterion, device)


In [None]:
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_one_epoch(bcr_model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(bcr_model, val_loader, criterion, device)

    # Report per-epoch metrics; epochs are numbered from 1 for display.
    train_part = f"Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.4f}"
    val_part = f"Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}"
    print(f"Epoch {epoch + 1}: {train_part} | {val_part}")