In [3]:
# Environment setup (the /content/... paths used later suggest Google Colab).
# Both `lightning` and `pytorch-lightning` are installed because the import cell
# uses `lightning` (Trainer / LightningModule) and `pytorch_lightning.loggers` (WandbLogger).
# NOTE(review): versions are unpinned, so re-runs may resolve different releases; pin for reproducibility.
!pip install lightning datasets transformers pytorch-lightning wandb -qq

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2024.12.0 which is incompatible.[0m[31m
[0m

In [1]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from datasets import load_dataset
import lightning as L
from pytorch_lightning.loggers import WandbLogger
import wandb

In [2]:
# Authenticate with Weights & Biases so the WandbLogger used in the training cell can upload the run.
# The cell's displayed output is the boolean returned by login().
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msawantch099[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
class LitTextClassification(L.LightningModule):
    """Sequence classifier: a (by default frozen) BERT encoder plus a trainable classification head.

    With ``freeze_backbone=True`` every parameter under ``self.model.bert`` is frozen,
    so only the classification head added by ``BertForSequenceClassification`` trains.

    Args:
        model_name: Hugging Face checkpoint to load. Defaults to ``"bert-base-uncased"``.
        num_labels: Number of output classes. Defaults to 2.
        lr: Adam learning rate. Defaults to ``1e-4`` (a slightly higher LR is fine
            when only the head is training).
        freeze_backbone: Freeze the BERT encoder so only the head receives gradients.
            Defaults to ``True``.
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        num_labels: int = 2,
        lr: float = 1e-4,
        freeze_backbone: bool = True,
    ):
        super().__init__()
        # Store constructor args in checkpoints so load_from_checkpoint rebuilds the same config.
        self.save_hyperparameters()
        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

        if freeze_backbone:
            # 🔒 Freeze BERT base model parameters; the classifier head stays trainable.
            for param in self.model.bert.parameters():
                param.requires_grad = False

        # Report exact counts: the head alone is ~1.5K params, which the previous
        # "{:.2f}M" format rounded down to a misleading "0.00M".
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        print(f"Trainable parameters: {trainable:,} / {total:,}")

    def training_step(self, batch):
        """Forward one batch with labels and return (and log) the loss computed by the HF model."""
        output = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        self.log("train_loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        """Adam over the trainable parameters only (just the head when the backbone is frozen)."""
        # Frozen params never receive gradients, so handing them to Adam only wastes bookkeeping.
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.hparams.lr)

In [4]:
class TextClassificationData(L.LightningDataModule):
    """IMDB movie-review training data tokenized for BERT, served as a shuffled DataLoader.

    Follows Lightning's split of responsibilities: ``prepare_data`` only downloads,
    ``setup`` tokenizes once, and ``train_dataloader`` just wraps the prepared dataset.

    Args:
        model_name: Tokenizer checkpoint; should match the classifier's backbone.
            Defaults to ``"bert-base-uncased"``.
        batch_size: Training batch size. Defaults to 16.
        num_workers: DataLoader worker processes. Defaults to 2.
        max_length: Pad/truncate every review to this many tokens. ``None`` (the
            default) keeps the tokenizer's own maximum length.
    """

    def __init__(self, model_name="bert-base-uncased", batch_size=16, num_workers=2, max_length=None):
        super().__init__()
        self.model_name = model_name
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_length = max_length
        self.train_dataset = None

    def prepare_data(self):
        # Download/cache only (no state assigned here), as Lightning intends for prepare_data.
        load_dataset("imdb")
        BertTokenizer.from_pretrained(self.model_name)

    def setup(self, stage=None):
        # Tokenize once here instead of on every train_dataloader() call.
        if stage not in (None, "fit") or self.train_dataset is not None:
            return
        tokenizer = BertTokenizer.from_pretrained(self.model_name)
        max_length = self.max_length  # bind locally so the map function does not capture `self`
        dataset = load_dataset("imdb")["train"]
        dataset = dataset.map(
            lambda batch: tokenizer(batch["text"], padding="max_length", truncation=True, max_length=max_length),
            batched=True,
        )
        dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        self.train_dataset = dataset

    def train_dataloader(self):
        if self.train_dataset is None:  # e.g. dataloader requested outside a Trainer
            self.setup("fit")
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )


In [5]:
# Seed Python/NumPy/torch (and dataloader workers) so shuffling and head init are reproducible.
L.seed_everything(42, workers=True)

model = LitTextClassification()
data = TextClassificationData()
# Take the logger from the same `lightning` package as the Trainer; importing it from
# `pytorch_lightning` loads a second, parallel copy of the framework alongside `lightning`.
trainer = L.Trainer(
    logger=L.pytorch.loggers.WandbLogger(project="Transfer Learning"),
    max_epochs=3,
    accelerator="auto",  # uses the GPU when available (as in the recorded run), CPU otherwise
)
trainer.fit(model, data)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Trainable parameters: 0.00M


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name  | Type                          | Params | Mode
---------------------------------------------------------------
0 | model | BertForSequenceClassification | 109 M  | eval
---------------------------------------------------------------
1.5 K     Trainable params
109 M     Non-trainable params
109 M     Total params
437.935   Total estimated model params size (MB)
0         Modules in train mode
231       Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name  | Type                          | Params | Mode
---------------------------------------------------------------
0 | model | BertForSequenceClassification | 109 M  | eval
---------------------------------------------------------------
1.5 K     Trainable params
109 M     Non-trainable params
109 M     Total params
437.935   Total estimated model params size (MB)
0  

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.


In [6]:
def predict_sentiment(text, checkpoint_path=None, model=None, tokenizer=None):
    """Classify the sentiment of a single text, print the details, and return the result.

    Args:
        text: The input string to classify.
        checkpoint_path: Optional path to a ``LitTextClassification`` Lightning checkpoint,
            loaded onto the CPU. Ignored when ``model`` is given.
        model: Optional already-loaded ``LitTextClassification`` to reuse across calls
            instead of reloading BERT each time.
        tokenizer: Optional already-loaded tokenizer; defaults to ``bert-base-uncased``.

    Returns:
        ``(label, probs)``: ``label`` is ``"negative"`` or ``"positive"``; ``probs`` is the
        softmax output tensor of shape ``(1, 2)``, on the CPU.
    """
    import warnings

    if model is None:
        if checkpoint_path:
            model = LitTextClassification.load_from_checkpoint(checkpoint_path, map_location="cpu")
        else:
            # A fresh module has a newly initialized classification head, so its output is not meaningful.
            warnings.warn(
                "predict_sentiment called without a checkpoint or model: the classification "
                "head is untrained and predictions are effectively random.",
                stacklevel=2,
            )
            model = LitTextClassification()  # __init__ already freezes the BERT encoder
    model.eval()

    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Keep inputs on the same device as the model (e.g. when a caller passes a GPU model).
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model.model(**inputs)

    probs = torch.softmax(outputs.logits, dim=-1).cpu()
    pred_class = torch.argmax(probs, dim=-1).item()

    class_labels = {
        0: "negative",
        1: "positive",
    }

    print(f"Text: {text}")
    print(f"Output probabilities: {probs}")
    print(f"Predicted class: {pred_class}")
    print(f"Predicted label: {class_labels[pred_class]}")

    return class_labels[pred_class], probs

In [18]:
text_input = "The staff was incredibly rude and unhelpful, and I felt completely ignored"
# Use the checkpoint written by the training run above rather than a hard-coded path:
# the W&B run id embedded in that path changes on every run.
checkpoint = trainer.checkpoint_callback.best_model_path
assert checkpoint, "No checkpoint was saved - run the training cell first."

predict_sentiment(text_input, checkpoint)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 0.00M
Text: The staff was incredibly rude and unhelpful, and I felt completely ignored
Output probabilities: tensor([[0.7273, 0.2727]])
Predicted class: 0
Predicted label: negative


('negative', tensor([[0.7273, 0.2727]]))