In [None]:
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import ir_datasets

class MSMARCODataset(Dataset):
    """Training triples (query, positive passage, negative passage) from MS MARCO.

    The constructor reads the ``msmarco-passage/train/triples-small`` collection
    through ``ir_datasets`` and keeps the raw text of every triple; tokenization
    is done lazily, one item at a time, in ``__getitem__``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='pt')``; its
            result is indexed with ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length every sequence is padded/truncated to.
        max_items: Maximum number of triples to keep. ``None`` keeps all of
            them; ``0`` yields an empty dataset.
    """

    def __init__(self, tokenizer, max_length=128, max_items=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        dataset = ir_datasets.load("msmarco-passage/train/triples-small")

        # Both lookup tables are fully materialised before any triple is read,
        # so memory use scales with the whole collection, not with max_items.
        queries = {q.query_id: q.text for q in dataset.queries_iter()}
        docs = {d.doc_id: d.text for d in dataset.docs_iter()}

        self.data = []
        for count, item in enumerate(dataset.docpairs_iter()):
            # Compare against None explicitly so that max_items=0 is honoured
            # as "keep nothing" instead of being mistaken for "no limit".
            if max_items is not None and count >= max_items:
                break
            self.data.append({
                'query': queries[item.query_id],
                'pos_doc': docs[item.doc_id_a],
                'neg_doc': docs[item.doc_id_b]
            })

    def __len__(self):
        """Return the number of stored triples."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize one text and return ``(input_ids, attention_mask)``."""
        encoded = self.tokenizer(text, truncation=True, padding='max_length',
                                 max_length=self.max_length, return_tensors='pt')
        # squeeze(0) drops only the leading axis (assumed to be the size-1
        # batch axis added by return_tensors='pt' -- TODO confirm). A bare
        # squeeze() would also collapse the sequence axis when max_length == 1.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return the tokenized triple at ``idx`` as a dict of tensors.

        Keys are ``q_ids``/``q_mask`` (query), ``p_ids``/``p_mask`` (positive
        passage) and ``n_ids``/``n_mask`` (negative passage).
        """
        item = self.data[idx]
        q_ids, q_mask = self._encode(item['query'])
        p_ids, p_mask = self._encode(item['pos_doc'])
        n_ids, n_mask = self._encode(item['neg_doc'])
        return {
            'q_ids': q_ids,
            'q_mask': q_mask,
            'p_ids': p_ids,
            'p_mask': p_mask,
            'n_ids': n_ids,
            'n_mask': n_mask
        }

class SPLADE(nn.Module):
    """Encoder that maps token sequences to one weight per vocabulary entry.

    Wraps a masked-language-model head and pools its per-token logits into a
    single vector per input sequence.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        super().__init__()
        self.bert_mlm = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return the pooled, log-saturated term weights for each sequence.

        Steps: ReLU on the MLM logits, zero out positions where
        ``attention_mask`` is 0, take the maximum over dim 1, cap the result
        at 10.0 and apply ``log1p``.
        """
        token_logits = self.bert_mlm(input_ids=input_ids,
                                     attention_mask=attention_mask).logits
        # Broadcasting the mask over the last axis silences masked positions.
        activations = F.relu(token_logits) * attention_mask.unsqueeze(-1)
        term_weights = activations.max(dim=1).values.clamp(max=10.0)
        # NOTE(review): log1p(0) is already well defined; the 1e-8 offset means
        # inactive terms are not exactly zero in float32 -- confirm it is wanted.
        return torch.log1p(term_weights + 1e-8)

def compute_loss(q_reps, p_reps, n_reps, lambda_q=1e-3, lambda_d=1e-4, tau=0.1):
    """Contrastive loss with in-batch negatives plus L1 penalties.

    Queries and documents are L2-normalized, every query is scored against
    every positive and negative document of the batch, and cross-entropy is
    taken with the query's own positive as the target class.

    Args:
        q_reps: Query representations, one row per query.
        p_reps: Positive document representations, row ``i`` belongs to query ``i``.
        n_reps: Negative document representations.
        lambda_q: Weight of the L1 penalty on the query representations.
        lambda_d: Weight of the L1 penalty on the document representations.
        tau: Temperature the similarity scores are divided by.

    Returns:
        Scalar loss tensor attached to the autograd graph of the inputs. A
        non-finite value (NaN/Inf) is returned unchanged so the caller can
        detect it and skip the optimizer step; it is no longer replaced by a
        detached zero, which hid divergence and produced no gradients.
    """
    batch_size = q_reps.size(0)

    q_unit = F.normalize(q_reps, p=2, dim=-1)
    # Positives come first, so column i of the score matrix is query i's positive.
    doc_unit = F.normalize(torch.cat([p_reps, n_reps], dim=0), p=2, dim=-1)

    scores = torch.matmul(q_unit, doc_unit.T) / tau
    scores = torch.clamp(scores, min=-100, max=100)

    labels = torch.arange(batch_size, device=q_unit.device)
    ce_loss = F.cross_entropy(scores, labels)

    # The penalties are computed on the normalized vectors, as before.
    l1_q = lambda_q * q_unit.abs().mean()
    l1_d = lambda_d * doc_unit.abs().mean()

    return ce_loss + l1_q + l1_d

# Command-line interface: the only option is where the trained weights end up.
parser = argparse.ArgumentParser(prog='splade_training')
parser.add_argument('-r', '--result', required=True, help='Output file')


if __name__ == '__main__':
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

    print('Loading dataset...')
    dataset = MSMARCODataset(tokenizer, max_items=3 * 32000)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    print('Training set has {} instances'.format(len(dataset)))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SPLADE().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
    # Mixed precision is only switched on when training runs on a GPU.
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    model.train()
    for epoch in range(1):
        print('EPOCH {}:'.format(epoch + 1))
        total_loss = 0.0   # sum of losses over the whole epoch
        total_steps = 0    # optimizer steps that contributed to total_loss
        running = 0.0      # sum of losses since the last progress line
        valid_steps = 0    # steps since the last progress line

        for i, batch in enumerate(dataloader):
            optimizer.zero_grad(set_to_none=True)
            # True between unscale_() and update(): the scaler then holds
            # per-optimizer state that must be reset before the next batch.
            pending_update = False

            try:
                with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.float16):
                    q_reps = model(batch['q_ids'].to(device), batch['q_mask'].to(device))
                    p_reps = model(batch['p_ids'].to(device), batch['p_mask'].to(device))
                    n_reps = model(batch['n_ids'].to(device), batch['n_mask'].to(device))
                    loss = compute_loss(q_reps, p_reps, n_reps)

                if not torch.isfinite(loss):
                    # Nothing was scaled for this batch, so the scaler is left alone.
                    print('  batch {} skipped: non-finite loss'.format(i + 1))
                    continue

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                pending_update = True
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                pending_update = False

                loss_val = loss.item()
            except RuntimeError as e:
                # Best effort: report the failure and move on to the next batch.
                print('  batch {} skipped: {}'.format(i + 1, e))
                if pending_update:
                    scaler.update()
                continue

            total_loss += loss_val
            total_steps += 1
            running += loss_val
            valid_steps += 1

            if (i + 1) % 100 == 0 and valid_steps > 0:
                print('  batch {} loss: {}'.format(i + 1, running/valid_steps))
                running = 0.0
                valid_steps = 0

        # Average over the steps that were actually taken, not over all batches.
        if total_steps > 0:
            print('Epoch {} avg loss = {}'.format(epoch + 1, total_loss / total_steps))

    model.eval()
    test_queries = [
        "what is python programming",
        "how to lose weight",
        "best restaurants in paris",
        "covid vaccine side effects",
        "machine learning tutorial",
        "climate change causes",
        "how to cook pasta",
        "bitcoin price prediction",
        "yoga benefits health",
        "electric cars pros cons"
    ]

    # Show the 20 highest-weighted vocabulary entries for each sample query.
    with torch.no_grad():
        for query in test_queries:
            tokens = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
            q_rep = model(tokens['input_ids'].to(device), tokens['attention_mask'].to(device))

            top = torch.topk(q_rep[0], k=20)

            print('\nQuery: {}'.format(query))
            print('Top tokens: {}'.format([(tokenizer.decode([idx]), '{:.2f}'.format(weight))
                                          for idx, weight in zip(top.indices.tolist(), top.values.tolist())]))

    # The weights are written locally first and then moved to the requested path.
    model_path = 'splade_checkpoint.pt'
    torch.save(model.state_dict(), model_path)

    print('Move {} to {}'.format(model_path, args.result))
    shutil.move(model_path, args.result)

![Learning loss](docs/loss_plot.png)

## Query encoding example:

```python
{
    
    "what is python programming": [('python', '1.97'), ('software', '0.78'), ('programming', '0.59'), ('language', '0.53'), ('logic', '0.44'), ('languages', '0.30'), ('technology', '0.24'), ('computer', '0.19'), ('c', '0.07'), ('[unused9]', '0.00'), ('[unused6]', '0.00'), ('[unused8]', '0.00'), ('[unused4]', '0.00'), ('[unused2]', '0.00'), ('[unused5]', '0.00'), ('[unused7]', '0.00'), ('[unused3]', '0.00'), ('[unused1]', '0.00'), ('[PAD]', '0.00'), ('[unused0]', '0.00')],

    "how to lose weight": [('loss', '1.29'), ('exercise', '1.26'), ('weight', '1.14'), ('yoga', '0.93'), ('diet', '0.89'), ('washing', '0.70'), ('lose', '0.49'), ('eating', '0.42'), ('daily', '0.32'), ('pills', '0.24'), ('change', '0.21'), ('losing', '0.19'), ('gain', '0.17'), ('training', '0.16'), ('body', '0.12'), ('[unused3]', '0.00'), ('[PAD]', '0.00'), ('[unused2]', '0.00'), ('[unused0]', '0.00'), ('[unused1]', '0.00')],

    "best restaurants in paris": [('restaurant', '1.48'), ('paris', '1.47'), ('hotel', '1.38'), ('france', '1.20'), ('french', '0.79'), ('food', '0.75'), ('restaurants', '0.71'), ('hotels', '0.58'), ('resort', '0.55'), ('club', '0.33'), ('europe', '0.29'), ('museums', '0.26'), ('store', '0.23'), ('town', '0.08'), ('shop', '0.07'), ('spain', '0.05'), ('[unused2]', '0.00'), ('[unused1]', '0.00'), ('[PAD]', '0.00'), ('[unused0]', '0.00')],

    "covid vaccine side effects": [('##vid', '1.99'), ('effects', '1.57'), ('co', '1.43'), ('nausea', '1.29'), ('poisoning', '1.25'), ('vaccine', '1.17'), ('headache', '0.60'), ('effect', '0.54'), ('pills', '0.44'), ('influenza', '0.17'), ('virus', '0.08'), ('[unused7]', '0.00'), ('[unused3]', '0.00'), ('[unused6]', '0.00'), ('[unused4]', '0.00'), ('[unused5]', '0.00'), ('[unused1]', '0.00'), ('[unused0]', '0.00'), ('[unused2]', '0.00'), ('[PAD]', '0.00')],

    "machine learning tutorial": [('machine', '1.72'), ('learning', '1.18'), ('education', '0.72'), ('training', '0.65'), ('skills', '0.64'), ('technology', '0.58'), ('teaching', '0.57'), ('language', '0.28'), ('software', '0.23'), ('instruction', '0.18'), ('science', '0.17'), ('theory', '0.11'), ('method', '0.11'), ('cognitive', '0.08'), ('equipment', '0.05'), ('school', '0.04'), ('degree', '0.03'), ('process', '0.02'), ('[unused0]', '0.00'), ('[PAD]', '0.00')],

    "climate change causes": [('climate', '1.69'), ('cause', '1.29'), ('impact', '1.19'), ('causes', '1.12'), ('warming', '1.04'), ('effect', '0.91'), ('change', '0.89'), ('causing', '0.88'), ('earth', '0.87'), ('effects', '0.74'), ('pollution', '0.71'), ('environment', '0.62'), ('increase', '0.58'), ('emissions', '0.58'), ('economy', '0.51'), ('global', '0.48'), ('cold', '0.29'), ('atmosphere', '0.24'), ('factors', '0.21'), ('source', '0.19')],

    "how to cook pasta": [('pasta', '1.80'), ('cooking', '0.89'), ('baking', '0.71'), ('cook', '0.69'), ('hours', '0.41'), ('boiling', '0.38'), ('oven', '0.24'), ('heat', '0.16'), ('[unused10]', '0.00'), ('[unused9]', '0.00'), ('[unused7]', '0.00'), ('[unused8]', '0.00'), ('[unused4]', '0.00'), ('[unused3]', '0.00'), ('[unused5]', '0.00'), ('[unused6]', '0.00'), ('[unused2]', '0.00'), ('[unused1]', '0.00'), ('[PAD]', '0.00'), ('[unused0]', '0.00')],
    
    "bitcoin price prediction": [('price', '1.34'), ('prediction', '1.25'), ('bit', '1.22'), ('stock', '1.21'), ('##co', '0.98'), ('value', '0.75'), ('percentage', '0.49'), ('bank', '0.40'), ('increase', '0.40'), ('rate', '0.19'), ('market', '0.18'), ('fortune', '0.17'), ('cash', '0.13'), ('accounting', '0.06'), ('[unused4]', '0.00'), ('[unused3]', '0.00'), ('[PAD]', '0.00'), ('[unused0]', '0.00'), ('[unused1]', '0.00'), ('[unused2]', '0.00')],

    "yoga benefits health": [('yoga', '2.03'), ('benefits', '0.97'), ('cure', '0.76'), ('exercise', '0.63'), ('healing', '0.57'), ('health', '0.44'), ('meditation', '0.37'), ('body', '0.33'), ('training', '0.11'), ('wellness', '0.06'), ('diet', '0.06'), ('mental', '0.06'), ('therapy', '0.05'), ('breathing', '0.01'), ('[unused4]', '0.00'), ('[unused3]', '0.00'), ('[PAD]', '0.00'), ('[unused0]', '0.00'), ('[unused1]', '0.00'), ('[unused2]', '0.00')],

    "electric cars pros cons": [('electric', '1.49'), ('vehicle', '1.12'), ('car', '0.81'), ('pro', '0.73'), ('con', '0.62'), ('engine', '0.46'), ('bmw', '0.42'), ('toyota', '0.30'), ('tesla', '0.25'), ('ford', '0.23'), ('power', '0.19'), ('driving', '0.17'), ('gear', '0.11'), ('fuel', '0.05'), ('ev', '0.04'), ('[unused3]', '0.00'), ('[PAD]', '0.00'), ('[unused2]', '0.00'), ('[unused0]', '0.00'), ('[unused1]', '0.00')]

}