<a href="https://colab.research.google.com/github/willt08/rosa/blob/main/_rosa_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **Install dependencies**

In [None]:
!pip install torch transformers

### **Define the Rosa model**

In [None]:
from transformers import BertModel, BertTokenizer
import torch
import torch.nn as nn

# Define the Rosa model class: a BERT encoder followed by a dropout + linear emotion head
class Rosa(nn.Module):
    """Emotion classifier: a BERT encoder topped with a dropout + linear head.

    The encoder's pooled output is regularised with dropout and projected to
    one logit per emotion.

    Args:
        model_name: Name or path passed to ``BertModel.from_pretrained``.
        num_emotions: Number of output logits (one per emotion label).
        dropout: Dropout probability applied to the pooled output before the
            linear head. Defaults to 0.3, the value previously hard-coded.
    """

    def __init__(
        self, model_name="bert-base-uncased", num_emotions=28, dropout=0.3
    ):
        super().__init__()
        # NOTE: ``heart``, ``grace`` and ``bloom`` are the prefixes of the keys
        # in a saved state_dict; renaming these attributes would break
        # ``load_state_dict`` on existing checkpoints such as rosa.pt.
        self.heart = BertModel.from_pretrained(model_name)
        self.grace = nn.Dropout(dropout)
        self.bloom = nn.Linear(self.heart.config.hidden_size, num_emotions)

    def forward(self, input_ids, attention_mask):
        """Return raw emotion logits (no activation is applied here).

        Args:
            input_ids: Token ids produced by the BERT tokenizer.
            attention_mask: Attention mask produced by the BERT tokenizer.

        Returns:
            Logits of shape ``(batch, num_emotions)``.
        """
        encoded = self.heart(input_ids=input_ids, attention_mask=attention_mask)
        # Use BERT's pooled output as the sequence-level representation.
        pooled = encoded.pooler_output
        # Dropout is only active in train mode; ``model.eval()`` disables it.
        return self.bloom(self.grace(pooled))


### **Load the model weights from Hugging Face**

In [None]:
# Download the Rosa model weights (rosa.pt) from the willt-dc/Rosa-V1 Hugging Face model repo
!wget https://huggingface.co/willt-dc/Rosa-V1/resolve/main/rosa.pt

### **Tokenize text and predict emotions**

In [None]:
# Load the tokenizer and model, with the model's weights mapped onto the CPU.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = Rosa(num_emotions=28)
# torch.load unpickles the file, and rosa.pt was downloaded from the internet.
# weights_only=True restricts unpickling to tensors and plain containers
# instead of arbitrary Python objects; a plain state_dict loads fine this way.
# (Requires torch >= 1.13; it is the default from torch 2.6.)
model.load_state_dict(
    torch.load("rosa.pt", map_location=torch.device("cpu"), weights_only=True)
)
# Switch to inference mode so the dropout layer is disabled.
model.eval()

# Run inference on your text
def predict(text):
    """Return per-emotion probabilities for ``text`` using the global model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: Input string. A list of strings also works, because the
            tokenizer accepts batches and ``padding=True`` pads them to a
            common length.

    Returns:
        Sigmoid probabilities, one per emotion. For a single string this is a
        1-D tensor of length ``num_emotions``; for a batch of several strings
        it is ``(batch, num_emotions)``.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Send inputs to wherever the model's weights live. This keeps the function
    # working after e.g. ``model.to("cuda")``.
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
        )
        # Sigmoid, not softmax: each emotion gets an independent probability.
        # squeeze(0) drops only the batch axis of a single-text batch. A bare
        # squeeze() would also collapse the emotion axis if it had size 1.
        probs = torch.sigmoid(logits).squeeze(0)
    return probs

# Example: score a single line of text and keep the probabilities in `probs`.
text = "And all I loved, I loved alone"
probs = predict(text=text)

# Emotion labels, listed in the order of the model's output logits.
# NOTE(review): this looks like the GoEmotions label order (alphabetical, with
# "neutral" last) — confirm it matches the order used when rosa.pt was trained.
emotion_labels = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization", "relief",
    "remorse", "sadness", "surprise", "neutral"
]

# zip() below stops at the shorter sequence. Without this check, a mismatch
# between the label list and the model's output width would silently drop
# scores or mislabel nothing visibly.
if len(emotion_labels) != probs.shape[-1]:
    raise ValueError(
        f"expected {probs.shape[-1]} emotion labels, got {len(emotion_labels)}"
    )

# Print one aligned "label: probability" line per emotion.
for label, prob in zip(emotion_labels, probs.tolist()):
    print(f"{label:<15}: {prob:.4f}")
