In [1]:
%pip install -U sentence-transformers datasets torch


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import ir_datasets
from tqdm import tqdm

class MSMARCODataset(Dataset):
    def __init__(self, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        dataset = ir_datasets.load("msmarco-passage/train/triples-small")
        
        queries = {q.query_id: q.text for q in tqdm(dataset.queries_iter(), desc="Loading queries")}
        docs = {d.doc_id: d.text for d in tqdm(dataset.docs_iter(), desc="Loading docs")}
        
        self.data = []
        for i, item in enumerate(tqdm(dataset.docpairs_iter(), desc="Loading pairs")):
            self.data.append({
                'query': queries[item.query_id],
                'pos_doc': docs[item.doc_id_a],
                'neg_doc': docs[item.doc_id_b]
            })
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        query = self.tokenizer(item['query'], truncation=True, padding='max_length', 
                               max_length=self.max_length, return_tensors='pt')
        pos_doc = self.tokenizer(item['pos_doc'], truncation=True, padding='max_length',
                                 max_length=self.max_length, return_tensors='pt')
        neg_doc = self.tokenizer(item['neg_doc'], truncation=True, padding='max_length',
                                 max_length=self.max_length, return_tensors='pt')
        return {
            'q_ids': query['input_ids'].squeeze(),
            'q_mask': query['attention_mask'].squeeze(),
            'p_ids': pos_doc['input_ids'].squeeze(),
            'p_mask': pos_doc['attention_mask'].squeeze(),
            'n_ids': neg_doc['input_ids'].squeeze(),
            'n_mask': neg_doc['attention_mask'].squeeze()
        }

class SPLADE(nn.Module):
    def __init__(self, model_name='distilbert-base-uncased'):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.linear = nn.Linear(768, 30522)
        
    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.linear(output.last_hidden_state)
        weights = torch.log(1 + torch.relu(logits))
        weights = torch.sum(weights * attention_mask.unsqueeze(-1), dim=1)
        return weights

def compute_loss(q_reps, p_reps, n_reps, lambda_reg=0.0001):
    batch_size = q_reps.size(0)
    
    pos_scores = torch.sum(q_reps * p_reps, dim=1)
    neg_scores = torch.sum(q_reps * n_reps, dim=1)
    
    all_docs = torch.cat([p_reps, n_reps], dim=0)
    all_scores = torch.matmul(q_reps, all_docs.T)
    
    labels = torch.arange(batch_size, device=q_reps.device)
    ce_loss = F.cross_entropy(all_scores, labels)
    
    l1_q = lambda_reg * torch.mean(torch.sum(torch.abs(q_reps), dim=1))
    l1_d = lambda_reg * torch.mean(torch.sum(torch.abs(all_docs), dim=1))
    
    return ce_loss + l1_q + l1_d

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = SPLADE().to(device)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)

dataset = MSMARCODataset(tokenizer)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

model.train()
for epoch in range(3):
    total_loss = 0
    running_loss = 0
    for i, batch in enumerate(tqdm(dataloader, desc=f'Epoch {epoch+1}')):
        optimizer.zero_grad()
        
        q_reps = model(batch['q_ids'].to(device), batch['q_mask'].to(device))
        p_reps = model(batch['p_ids'].to(device), batch['p_mask'].to(device))
        n_reps = model(batch['n_ids'].to(device), batch['n_mask'].to(device))
        
        loss = compute_loss(q_reps, p_reps, n_reps)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        running_loss += loss.item()
        
        if (i + 1) % 1000 == 0:
            tqdm.write(f'Step {i+1}: avg loss = {running_loss/1000:.4f}')
            running_loss = 0
    
    tqdm.write(f'Epoch {epoch+1} avg loss: {total_loss/len(dataloader):.4f}')

torch.save(model.state_dict(), 'splade_v1.pt')


Loading queries: 808731it [00:01, 556048.21it/s]
Loading docs: 8841823it [00:25, 346259.14it/s]
Loading pairs: 39780811it [03:14, 204265.52it/s]
Epoch 1:   0%|          | 0/1243151 [00:00<?, ?it/s]

{'q_ids': tensor([[ 101, 2073, 2515,  ...,    0,    0,    0],
        [ 101, 6412, 1997,  ...,    0,    0,    0],
        [ 101, 2129, 2146,  ...,    0,    0,    0],
        ...,
        [ 101, 2779, 3199,  ...,    0,    0,    0],
        [ 101, 4633, 1999,  ...,    0,    0,    0],
        [ 101, 2054, 2828,  ...,    0,    0,    0]]), 'q_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'p_ids': tensor([[  101,  2073,  2515,  ...,     0,     0,     0],
        [  101,  1996,  2942,  ...,     0,     0,     0],
        [  101,  1996,  3452,  ...,     0,     0,     0],
        ...,
        [  101,  2045,  2024,  ...,     0,     0,     0],
        [  101, 24800,  1021,  ...,     0,     0,     0],
        [  101,  3763, 10925,  ...,  1012,   102,     0]]), 'p_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1, 

Epoch 1:   0%|          | 1/1243151 [00:09<3433:33:18,  9.94s/it]

{'q_ids': tensor([[  101,  2054,  3361,  ...,     0,     0,     0],
        [  101, 26568, 27303,  ...,     0,     0,     0],
        [  101,  2054,  2003,  ...,     0,     0,     0],
        ...,
        [  101,  2054,  1005,  ...,     0,     0,     0],
        [  101,  2129,  2146,  ...,     0,     0,     0],
        [  101,  2129,  2172,  ...,     0,     0,     0]]), 'q_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'p_ids': tensor([[  101,  7079,  2115,  ...,     0,     0,     0],
        [  101,  8856,  4098,  ...,     0,     0,     0],
        [  101,  1996,  3177,  ...,     0,     0,     0],
        ...,
        [  101,  2769,  2915,  ...,     0,     0,     0],
        [  101, 20228,  2015,  ...,     0,     0,     0],
        [  101,  1996,  8417,  ...,  1996,  5025,   102]]), 'p_mask': tensor([[1, 1, 1

Epoch 1:   0%|          | 1/1243151 [00:16<5702:18:17, 16.51s/it]


KeyboardInterrupt: 

In [48]:
print(tokenizer.decode(batch['q_ids'][2], skip_special_tokens=True))
print(tokenizer.decode(batch['p_ids'][2], skip_special_tokens=True))
print(tokenizer.decode(batch['n_ids'][2], skip_special_tokens=True))

what is sound times speed
the speed of light is much faster than the speed of sound. sound moves at 343 metres per second ( thats about 770 miles per hour! ). the speed of light is a little trickier. in a vacuum it is roughly 300 000 000 metres per second ( so nearly 900 000 times faster than sound! ). however, when light moves through air or glass it gets slowed down a little bit so light moving through glass moves at 200 000 000 metres per second which is still 580 000 times faster than sound.
what determines the pitch of sound? pitch is determined by the frequency of a wave, and frequency is the combination of wavelength and speed at which the wave is traveling. sound has a constant speed of 343 meters per second, so wavelength dictates pitch. the longer a sound ' s wavelength, the lower the pitch of that sound.


In [19]:
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')

reverse_voc = {v: k for k, v in tokenizer.vocab.items()}


NameError: name 'x' is not defined