In [1]:
!pip install transformers
!pip install torch


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 7.3 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 56.8 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 12.8 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.10.1 tokenizers-0.13.2 transformers-4.24.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
# Basic usage example, taken as-is from the Hugging Face model card

from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification, TextClassificationPipeline
import torch

def preprocess(comment):
  """Reduce a source-code comment to its letters, digits and spaces.

  Every character that is neither alphanumeric (as judged by
  ``str.isalnum``) nor a literal space ``' '`` is dropped. That removes
  comment delimiters and punctuation, but also underscores, tabs and
  newlines, so words separated only by those characters are fused.

  Args:
    comment: The raw comment text (an iterable of characters).

  Returns:
    A new string containing only the kept characters, in original order.
  """
  kept = [ch for ch in comment if ch == ' ' or ch.isalnum()]
  return ''.join(kept)

# Tokenizer comes from the base RoBERTa checkpoint
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

# Run on the first CUDA GPU when one is available, otherwise on the CPU
# (use whatever device works for you)
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Number of output classes requested for the sequence-classification head
num_labels = 12

# Load the fine-tuned checkpoint, then move it to the selected device
model = AutoModelForSequenceClassification.from_pretrained(
    "skylergrandel/Comcat",
    num_labels=num_labels,
)
model = model.to(DEVICE)

# Bundle model + tokenizer into a callable that maps text -> [{'label', 'score'}]
classifier = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    device=DEVICE,
)

# Smoke test: classify one preprocessed comment and show the result
prediction = classifier(
    preprocess('/* This is a comment that we want to classify */')
)
print('Prediction:', prediction)

Prediction: [{'label': 'LABEL_5', 'score': 0.6703650951385498}]


In [19]:
# Example comments to classify, one per list entry.
# text[0] and text[8] are taken from a dataset of unlabeled comments (to avoid
# writing long comments by hand); all of the other entries were made up.
# In the recorded run of the next cell, text[i] was classified as LABEL_i for every i.
text = [
    # text[0] -- taken from the unlabeled-comment dataset
    '''/**
* Comparison function for tr_announce_requests.
*
* The primary key (amount of data transferred) is used to prioritize
* tracker announcements of active torrents. The remaining keys are
* used to satisfy the uniqueness requirement of a sorted tr_ptrArray.
*/''',
    # text[1] .. text[7] -- made-up examples
    '/* sum contains the sum of all of the numbers in the list */',
    '/* walk through the list, turning all of the 0s to -1s */',
    '/* if the number is even, put it in A; otherwise, put it in B */',
    '/* We loop in reverse because it is the most efficient way */',
    '/* This is a useless comment */',
    '/* This should be called by a collector thread */',
    '/* from example in the glib documentation at www.examplewebsite.com/glibdocs/get */',
    # text[8] -- taken from the unlabeled-comment dataset
    '''
    /* -*- Mode: C; tab-width: 8; indent-tabs-mode: t; c-basic-offset: 8; coding: utf-8 -*- *//*
      * This file is part of GtkSourceView
      *
      * Copyright (C) 2003 - Gustavo Giráldez
      * Copyright (C) 2006, 2013 - Paolo Borelli
      * Copyright (C) 2013, 2016 - Sébastien Wilmet
      *
      * GtkSourceView is free software; you can redistribute it and/or
      * modify it under the terms of the GNU Lesser General Public
      * License as published by the Free Software Foundation; either
      * version 2.1 of the License, or (at your option) any later version.
      *
      * GtkSourceView is distributed in the hope that it will be useful,
      * but WITHOUT ANY WARRANTY; without even the implied warranty of
      * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      * Lesser General Public License for more details.
      *
      * You should have received a copy of the GNU Lesser General Public
      * License along with this library; if not, write to the Free Software
      * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
      */
    ''',
    # text[9] .. text[11] -- made-up examples
    '/* The following tests are to test foo() */',
    '/* int y = x+10 */',
    '/* TODO: add failure test cases */',
]

In [20]:
# Classify every example and print the result next to its index in `text`
for i, comment in enumerate(text):
  prediction = classifier(preprocess(comment))
  print('Prediction for text[', i, ']:', prediction)

Prediction for text[ 0 ]: [{'label': 'LABEL_0', 'score': 0.8686641454696655}]
Prediction for text[ 1 ]: [{'label': 'LABEL_1', 'score': 0.7577952146530151}]
Prediction for text[ 2 ]: [{'label': 'LABEL_2', 'score': 0.9838972687721252}]
Prediction for text[ 3 ]: [{'label': 'LABEL_3', 'score': 0.9059826135635376}]
Prediction for text[ 4 ]: [{'label': 'LABEL_4', 'score': 0.4900676906108856}]
Prediction for text[ 5 ]: [{'label': 'LABEL_5', 'score': 0.9692336916923523}]
Prediction for text[ 6 ]: [{'label': 'LABEL_6', 'score': 0.9262851476669312}]
Prediction for text[ 7 ]: [{'label': 'LABEL_7', 'score': 0.2945345342159271}]
Prediction for text[ 8 ]: [{'label': 'LABEL_8', 'score': 0.4667740762233734}]
Prediction for text[ 9 ]: [{'label': 'LABEL_9', 'score': 0.4376959204673767}]
Prediction for text[ 10 ]: [{'label': 'LABEL_10', 'score': 0.6370751261711121}]
Prediction for text[ 11 ]: [{'label': 'LABEL_11', 'score': 0.9256995916366577}]


In [24]:
# Of course, some categories perform better than others. 
# For example, we have limited training data for categories 7 and 8

# One interesting failure case is when we use informal language for one of the more useful categories, 
# because the model seems to think that comments with informal or indirect language belong in category 5

# Failure case 1: a comment that I made up
made_up_comment = '/* We just kinda want to mix around the letters in this string */'
print('This should be LABEL_2:', classifier(preprocess(made_up_comment)))

# Failure case 2: a real comment from the test set
test_set_comment = '/* Windows 98 appears to asynchronously create and remove  *//* writable memory mappings, for reasons we havent yet    *//* understood.  Since we look for writable regions to      *//* determine the root set, we may try to mark from an      *//* address range that disappeared since we started the     *//* collection.  Thus we have to recover from faults here.  *//* This code does not appear to be necessary for Windows   *//* 95/NT/2000. Note that this code should never generate   *//* an incremental GC write fault.                          */'
print('This should be LABEL_4:', classifier(preprocess(test_set_comment)))

This should be LABEL_2: [{'label': 'LABEL_5', 'score': 0.9368220567703247}]
This should be LABEL_4: [{'label': 'LABEL_5', 'score': 0.9527665972709656}]
