In [1]:
%pip install -U sentence-transformers datasets torch ir_datasets transformers

Defaulting to user installation because normal site-packages is not writeable
Collecting sentence-transformers
  Downloading sentence_transformers-5.1.2-py3-none-any.whl.metadata (16 kB)
Collecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting torch
  Downloading torch-2.9.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting transformers
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting huggingface-hub>=0.20.0 (from sentence-transformers)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface-hub>=0.20.0->sentence-transformers)
  Downloading hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downlo

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import ir_datasets
from tqdm import tqdm

In [None]:
class MSMARCODataset(Dataset):
    """MS MARCO passage training triples, tokenized lazily on access.

    Every element is a (query, positive passage, negative passage) triple read
    from the ``msmarco-passage/train/triples-small`` ir_datasets collection.
    The raw texts are held in memory; tokenization happens in ``__getitem__``.

    Args:
        tokenizer: Callable used as
            ``tokenizer(text, truncation=True, padding='max_length',
            max_length=..., return_tensors='pt')`` that returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` entries whose leading
            axis has size 1.
        max_length: Fixed sequence length every text is padded/truncated to.
        max_pairs: Optional cap on the number of triples to load. ``None``
            (the default) loads every triple, as before.
    """

    def __init__(self, tokenizer, max_length=128, max_pairs=None):
        from itertools import islice

        self.tokenizer = tokenizer
        self.max_length = max_length
        dataset = ir_datasets.load("msmarco-passage/train/triples-small")

        queries = {q.query_id: q.text for q in tqdm(dataset.queries_iter(), desc="Loading queries")}
        docs = {d.doc_id: d.text for d in tqdm(dataset.docs_iter(), desc="Loading docs")}

        pairs = dataset.docpairs_iter()
        if max_pairs is not None:
            # Stop reading early so a quick experiment does not have to
            # materialize the full triple list.
            pairs = islice(pairs, max_pairs)

        self.data = [
            {
                'query': queries[item.query_id],
                'pos_doc': docs[item.doc_id_a],
                'neg_doc': docs[item.doc_id_b]
            }
            for item in tqdm(pairs, desc="Loading pairs")
        ]

    def __len__(self):
        """Return the number of loaded triples."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize one text and return ``(input_ids, attention_mask)`` without the leading axis."""
        encoded = self.tokenizer(text, truncation=True, padding='max_length',
                                 max_length=self.max_length, return_tensors='pt')
        # squeeze(0) removes only the leading size-1 axis; a bare squeeze()
        # would also drop the sequence axis whenever max_length == 1.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return the token ids and attention masks of triple ``idx``.

        Returns:
            dict with keys ``q_ids``/``q_mask`` (query), ``p_ids``/``p_mask``
            (positive passage) and ``n_ids``/``n_mask`` (negative passage).
        """
        item = self.data[idx]
        q_ids, q_mask = self._encode(item['query'])
        p_ids, p_mask = self._encode(item['pos_doc'])
        n_ids, n_mask = self._encode(item['neg_doc'])
        return {
            'q_ids': q_ids,
            'q_mask': q_mask,
            'p_ids': p_ids,
            'p_mask': p_mask,
            'n_ids': n_ids,
            'n_mask': n_mask
        }

    
# Load the tokenizer in this cell: in notebook order nothing above defines it,
# so relying on a later cell would fail after Restart Kernel -> Run All.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
dataset = MSMARCODataset(tokenizer)
# Shuffled mini-batches of 32 triples for the training loops.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [44]:
class SPLADE(nn.Module):
    """Sparse lexical encoder built on a masked-language-model head.

    The forward pass turns the per-token vocabulary logits of the MLM into a
    single vocabulary-sized weight vector per input sequence.

    Args:
        model_name: Checkpoint name passed to
            ``AutoModelForMaskedLM.from_pretrained``.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        # Imported here because the notebook's import cell only brings in
        # AutoModel / AutoTokenizer; without this the class fails on a fresh
        # kernel with a NameError.
        from transformers import AutoModelForMaskedLM

        super().__init__()
        self.bert_mlm = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token sequences.

        Args:
            input_ids: Token ids forwarded unchanged to the MLM.
            attention_mask: Mask forwarded to the MLM and also used to zero
                out masked positions before pooling; it is broadcast against
                the logits through a trailing singleton axis.

        Returns:
            ``log1p`` of the maximum (over dim 1) of the ReLU-ed, masked
            logits, i.e. one non-negative weight per vocabulary entry.
        """
        outputs = self.bert_mlm(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        relu_logits = F.relu(logits)
        # Masked positions contribute 0, which can never beat a positive
        # activation in the max below.
        relu_logits = relu_logits * attention_mask.unsqueeze(-1)
        pooled, _ = torch.max(relu_logits, dim=1)
        weights = torch.log1p(pooled)
        return weights

def compute_loss(q_reps, p_reps, n_reps, lambda_q=1e-3, lambda_d=1e-4, tau=0.07):
    """Contrastive loss with in-batch negatives plus L1 sparsity penalties.

    Each query is scored by dot product against every positive and every
    negative document of the batch; the target for query ``i`` is positive
    ``i``. The scores are divided by ``tau`` before the cross-entropy.

    Args:
        q_reps: Query representations, one row per query.
        p_reps: Positive document representations, row-aligned with ``q_reps``.
        n_reps: Negative document representations.
        lambda_q: Weight of the mean per-row L1 norm of the queries.
        lambda_d: Weight of the mean per-row L1 norm of all documents.
        tau: Temperature dividing the similarity scores.

    Returns:
        Scalar tensor: cross-entropy + query penalty + document penalty.
    """
    doc_reps = torch.cat((p_reps, n_reps), dim=0)
    logits = (q_reps @ doc_reps.T) / tau
    # Positives occupy the first rows of doc_reps, so query i targets column i.
    targets = torch.arange(q_reps.size(0), device=q_reps.device, dtype=torch.long)
    contrastive = F.cross_entropy(logits, targets)
    query_penalty = lambda_q * q_reps.abs().sum(dim=1).mean()
    doc_penalty = lambda_d * doc_reps.abs().sum(dim=1).mean()
    return contrastive + query_penalty + doc_penalty

# Training run: encode queries, positives and negatives with the same SPLADE
# model and optimise the contrastive + sparsity loss.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = SPLADE().to(device)
# Only parameters that require gradients are handed to the optimizer.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)

model.train()
# A step counter that keeps running across epochs: deriving the checkpoint
# name from the per-epoch batch index would make later epochs overwrite the
# checkpoints written by earlier ones.
global_step = 0
running_loss = 0
for epoch in range(3):
    total_loss = 0
    for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}'):
        optimizer.zero_grad()

        q_reps = model(batch['q_ids'].to(device), batch['q_mask'].to(device))
        p_reps = model(batch['p_ids'].to(device), batch['p_mask'].to(device))
        n_reps = model(batch['n_ids'].to(device), batch['n_mask'].to(device))

        loss = compute_loss(q_reps, p_reps, n_reps)
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call copies the value to the host.
        loss_value = loss.item()
        total_loss += loss_value
        running_loss += loss_value
        global_step += 1

        if global_step % 1000 == 0:
            torch.save(model.state_dict(), f"splade_v{global_step // 1000}.pt")
            tqdm.write(f'Step {global_step}: avg loss = {running_loss/1000:.4f}')
            running_loss = 0

    tqdm.write(f'Epoch {epoch+1} avg loss: {total_loss/len(dataloader):.4f}')



Epoch 1:   0%|          | 1000/1243151 [06:06<564:04:13,  1.63s/it]

Step 1000: avg loss = 36.2805


Epoch 1:   0%|          | 2000/1243151 [12:10<572:23:30,  1.66s/it]

Step 2000: avg loss = 1.9583


Epoch 1:   0%|          | 3000/1243151 [18:14<560:55:14,  1.63s/it]

Step 3000: avg loss = 1.2410


Epoch 1:   0%|          | 4000/1243151 [24:18<564:20:55,  1.64s/it]

Step 4000: avg loss = 0.8858


Epoch 1:   0%|          | 5000/1243151 [30:23<576:30:30,  1.68s/it]

Step 5000: avg loss = 0.7511


Epoch 1:   0%|          | 5010/1243151 [30:27<125:27:47,  2.74it/s]


KeyboardInterrupt: 

In [None]:
import torch
import random
from transformers import AutoTokenizer

# Qualitative check: load a saved checkpoint and print the 20 highest-weighted
# vocabulary entries for a handful of queries.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SPLADE().to(device)
# map_location lets a checkpoint that was written on a GPU be loaded on a
# CPU-only machine as well.
model.load_state_dict(torch.load('splade_v5.pt', map_location=device))
model.eval()

test_queries = [
    "what is python programming",
    "how to lose weight",
    "best restaurants in paris",
    "covid vaccine side effects",
    "machine learning tutorial",
    "climate change causes",
    "how to cook pasta",
    "bitcoin price prediction",
    "yoga benefits health",
    "electric cars pros cons"
]

# id -> token lookup; not used by the loop below but kept for interactive use.
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}

with torch.no_grad():
    for query in test_queries:
        tokens = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
        q_rep = model(tokens['input_ids'].to(device), tokens['attention_mask'].to(device))

        # One topk call yields both the indices and their weights.
        top = torch.topk(q_rep[0], k=20)
        top_indices = top.indices.tolist()
        top_weights = top.values.tolist()

        print(f"\nQuery: {query}")
        print("Top tokens:", [(tokenizer.decode([idx]), f"{weight:.2f}") for idx, weight in zip(top_indices, top_weights)])


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: cf0b626c-a401-45a8-8a84-31a64e0176da)')' thrown while requesting HEAD https://huggingface.co/distilbert-base-uncased/resolve/main/config.json
Retrying in 1s [Retry 1/5].



Query: what is python programming
Top tokens: [('.', '2.81'), (',', '2.73'), ('!', '2.55'), ('?', '2.43'), (';', '2.42'), ('-', '2.37'), ("'", '2.30'), (':', '2.26'), ('...', '2.23'), (')', '2.22'), ('"', '1.86'), ('python', '1.53'), ('}', '0.68'), ('programming', '0.57'), ('program', '0.46'), ('the', '0.30'), ('bob', '0.10'), ('[PAD]', '0.00'), ('[unused0]', '0.00'), ('[unused1]', '0.00'), ('[unused2]', '0.00'), ('[unused3]', '0.00'), ('[unused4]', '0.00'), ('[unused5]', '0.00'), ('[unused6]', '0.00'), ('[unused7]', '0.00'), ('[unused8]', '0.00'), ('[unused9]', '0.00'), ('[unused10]', '0.00'), ('[unused11]', '0.00'), ('[unused12]', '0.00'), ('[unused13]', '0.00'), ('[unused14]', '0.00'), ('[unused15]', '0.00'), ('[unused16]', '0.00'), ('[unused17]', '0.00'), ('[unused18]', '0.00'), ('[unused19]', '0.00'), ('[unused20]', '0.00'), ('[unused21]', '0.00'), ('[unused22]', '0.00'), ('[unused23]', '0.00'), ('[unused24]', '0.00'), ('[unused25]', '0.00'), ('[unused26]', '0.00'), ('[unused27]'

In [26]:
# Display the module tree of the current `model` object.
# NOTE(review): the saved repr lists a `bert` submodule, which does not match
# the `bert_mlm` attribute of the SPLADE classes defined in this notebook —
# presumably it came from an earlier definition; confirm and re-run.
model

SPLADE(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_featu

In [31]:
# Smallest entry of the weight tensor of the model's `linear` attribute.
# NOTE(review): no SPLADE class in this notebook defines a `linear` attribute,
# so this presumably relies on a model left over from an earlier, since-changed
# cell — confirm before relying on it after a kernel restart.
model.linear.weight.min()

tensor(-0.0658, device='cuda:0', grad_fn=<MinBackward1>)

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
from tqdm import tqdm

class SPLADE(nn.Module):
    """Sparse lexical encoder whose base model is frozen.

    Wraps a masked-language model, disables gradients for every parameter
    that belongs to its base model, and leaves the remaining parameters
    trainable. The forward pass pools the per-token vocabulary logits into one
    weight vector per sequence.

    Args:
        model_name: Checkpoint name passed to
            ``AutoModelForMaskedLM.from_pretrained``.
    """

    # Substrings used to recognise base-model parameters when the wrapped
    # model does not expose a usable ``base_model_prefix``.
    _BASE_NAME_MARKERS = ("distilbert", "bert", "roberta", "base_model")

    def __init__(self, model_name='distilbert-base-uncased'):
        super().__init__()
        self.bert_mlm = AutoModelForMaskedLM.from_pretrained(model_name)
        self._freeze_base_model(self.bert_mlm)

    @classmethod
    def _freeze_base_model(cls, mlm_model):
        """Set ``requires_grad`` to False on base-model parameters, True on the rest."""
        base_prefix = getattr(mlm_model, "base_model_prefix", None)
        # NOTE(review): a parameter shared between the base model and the head
        # is presumably listed once, under whichever name comes first; if that
        # name carries the base prefix it is frozen here — confirm this is the
        # intended treatment of tied weights.
        for name, param in mlm_model.named_parameters():
            if base_prefix:
                in_base = name.startswith(base_prefix + ".")
            else:
                # A missing *or empty* prefix falls back to name matching; an
                # empty prefix would otherwise test startswith(".") and
                # silently freeze nothing.
                in_base = any(marker in name for marker in cls._BASE_NAME_MARKERS)
            param.requires_grad = not in_base

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token sequences.

        Args:
            input_ids: Token ids forwarded unchanged to the MLM.
            attention_mask: Mask forwarded to the MLM and also used to zero
                out masked positions before pooling.

        Returns:
            ``log1p`` of the maximum (over dim 1) of the ReLU-ed, masked
            logits.
        """
        outputs = self.bert_mlm(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        relu_logits = F.relu(logits)
        # Masked positions are zeroed so they cannot win the max below.
        relu_logits = relu_logits * attention_mask.unsqueeze(-1)
        pooled, _ = torch.max(relu_logits, dim=1)
        weights = torch.log1p(pooled)
        return weights

    
def compute_loss(q_reps, p_reps, n_reps, lambda_q=1e-3, lambda_d=1e-4, tau=0.07):
    """Temperature-scaled in-batch cross-entropy with L1 regularisation.

    Positives and negatives are stacked into one candidate pool; query ``i``
    must rank candidate ``i`` (its own positive) first. Two L1 terms, weighted
    by ``lambda_q`` and ``lambda_d``, are added for the query and candidate
    representations respectively.

    Args:
        q_reps: Query representations, one row per query.
        p_reps: Positive representations, row-aligned with ``q_reps``.
        n_reps: Negative representations.
        lambda_q: Weight of the query L1 term.
        lambda_d: Weight of the document L1 term.
        tau: Temperature the similarity scores are divided by.

    Returns:
        Scalar tensor holding the total loss.
    """
    def mean_l1(reps):
        # Average over rows of the per-row sum of absolute values.
        return torch.mean(reps.abs().sum(dim=1))

    candidates = torch.cat([p_reps, n_reps], dim=0)
    num_queries = q_reps.size(0)
    positive_index = torch.arange(num_queries, device=q_reps.device, dtype=torch.long)
    similarity = torch.matmul(q_reps, candidates.T).div(tau)

    ranking_term = F.cross_entropy(similarity, positive_index)
    query_term = lambda_q * mean_l1(q_reps)
    doc_term = lambda_d * mean_l1(candidates)
    return ranking_term + query_term + doc_term

# Training run for the SPLADE variant with a frozen base model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = SPLADE().to(device)
# Hand the optimizer only the parameters that are actually trained.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)

model.train()
# Counts optimizer steps across all epochs so that checkpoint names never
# repeat from one epoch to the next.
global_step = 0
running_loss = 0
for epoch in range(3):
    total_loss = 0
    for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}'):
        optimizer.zero_grad()
        q_reps = model(batch['q_ids'].to(device), batch['q_mask'].to(device))
        p_reps = model(batch['p_ids'].to(device), batch['p_mask'].to(device))
        n_reps = model(batch['n_ids'].to(device), batch['n_mask'].to(device))
        loss = compute_loss(q_reps, p_reps, n_reps, lambda_q=1e-3, lambda_d=1e-4, tau=0.07)
        loss.backward()
        # Apply the update before checkpointing; saving first would store
        # weights that are one step behind the reported step number.
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        running_loss += loss_value
        global_step += 1

        if global_step % 1000 == 0:
            torch.save(model.state_dict(), f"splade_working_v{global_step // 1000}.pt")
            tqdm.write(f'Step {global_step}: avg loss = {running_loss/1000:.4f}')
            running_loss = 0

    tqdm.write(f'Epoch {epoch+1} avg loss: {total_loss/len(dataloader):.4f}')


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 5a89fdb4-83ac-4670-bdeb-486588b3282a)')' thrown while requesting HEAD https://huggingface.co/distilbert-base-uncased/resolve/main/tokenizer_config.json
Retrying in 1s [Retry 1/5].
Epoch 1:   0%|          | 1000/1243151 [03:13<543:06:32,  1.57s/it]

Step 1000: avg loss = 555.0217


Epoch 1:   0%|          | 2000/1243151 [06:23<503:51:35,  1.46s/it]

Step 2000: avg loss = 67.0492


Epoch 1:   0%|          | 2006/1243151 [06:25<66:12:53,  5.21it/s] 


KeyboardInterrupt: 