Get rid of some class on PR-curve #3642

Closed
snow-tyan opened this issue Jun 16, 2021 · 2 comments
Labels
question Further information is requested

Comments

@snow-tyan

❔Question

I trained on the BDD100K dataset using yolov5m.yaml:
python train.py --cfg yolov5m.yaml --data data/bdd100k.yaml --weights yolov5m.pt --epochs 50 --hyp hyp.finetune.yaml --rect
and got about 40% mAP@.5:
| Class | AP50 |
|-------|------|
| person | 0.5080 |
| rider | 0.1930 |
| car | 0.6730 |
| bus | 0.4480 |
| truck | 0.5050 |
| bike | 0.3930 |
| motor | 0.2580 |
| tra-l | 0.4770 |
| tra-s | 0.5210 |
| train | 0.0000 |

Additional context

[Image: PR_curve]

I noticed that some classes, such as train, have very few labels. How can I draw a new PR curve without retraining, and get the new 44.18% mAP@.5 (without 'train')?

It seems a bit difficult :(
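Since mAP@.5 is the unweighted mean of the per-class AP@.5 values, the 44.18% figure can be checked directly from the table above without retraining: dropping `train` (AP 0.0000) divides the same sum by 9 instead of 10. A minimal sketch; the `map50` helper is hypothetical and not part of YOLOv5.

```python
# Per-class AP@0.5 values copied from the table above (BDD100K, yolov5m)
ap50 = {
    'person': 0.5080, 'rider': 0.1930, 'car': 0.6730, 'bus': 0.4480, 'truck': 0.5050,
    'bike': 0.3930, 'motor': 0.2580, 'tra-l': 0.4770, 'tra-s': 0.5210, 'train': 0.0000,
}


def map50(ap_per_class, exclude=()):
    """Hypothetical helper: mean of per-class AP@0.5, skipping classes named in `exclude`."""
    kept = [ap for name, ap in ap_per_class.items() if name not in exclude]
    return sum(kept) / len(kept)


print(f'all 10 classes: {map50(ap50):.4f}')
print(f'without train : {map50(ap50, exclude={"train"}):.4f}')
```

With all ten classes this gives 0.3976 (the "about 40%" above); without `train` it gives 0.4418.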

@snow-tyan snow-tyan added the question Further information is requested label Jun 16, 2021
@glenn-jocher
Member

@snow-tyan no, this is not possible. AP per class is displayed along with mAP for all classes in your dataset.

@snow-tyan
Author

snow-tyan commented Jun 17, 2021

@glenn-jocher I tried modifying the plot_pr_curve function in utils/metrics.py and got the cls9-PR_curve and cls7-PR_curve plots.
Emmm, at least it looks better :)

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path


def plot_pr_curve_cls9(px, py, ap, save_dir='pr_curve.png', names=()):
    # PR curve over 9 classes: same as plot_pr_curve, but 'train' is left out of the legend and the mAP
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)  # (1000, nc)

    # names = {0: 'person', 1: 'rider', 2: 'car', ..., 9: 'train'}
    # indices of the classes to keep, i.e. everything except 'train' (no hard-coded slice)
    nc = py.shape[1]
    keep = [i for i in range(nc) if names[i] != 'train'] if len(names) else list(range(nc))

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i in keep:
            ax.plot(px, py[:, i], linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
    else:
        ax.plot(px, py[:, keep], linewidth=1, color='grey')  # plot(recall, precision)

    # ap is (nc, 10) and column 0 is AP@0.5; average curve and AP over the kept classes only
    ap50 = ap[keep, 0]
    ax.plot(px, py[:, keep].mean(1), linewidth=3, color='blue', label=f'mAP@0.5 {ap50.mean():.3f}')
    # ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
    fig.savefig(Path(save_dir), dpi=250)
    plt.close(fig)  # free the figure after saving

[Image: cls9-PR_curve]
[Image: cls7-PR_curve]
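When more than one class has to go (the cls7 plot presumably keeps only 7 of the 10 classes, though which ones is not stated), hard-coding names inside the plotting function gets awkward. An alternative sketch is to filter the inputs and leave the plotting function untouched, assuming the stock `plot_pr_curve(px, py, ap, save_dir, names)` averages `py` and `ap[:, 0]` over every class it receives, as the commented-out line in the code above suggests. `drop_classes` is a hypothetical helper, not part of YOLOv5, and where exactly to call it (e.g. inside `ap_per_class`, just before the PR curve is plotted) depends on the version of utils/metrics.py.

```python
import numpy as np


def drop_classes(py, ap, names, exclude):
    """Hypothetical helper: remove the classes named in `exclude` before plotting.

    py:      list of nc precision curves, each of shape (1000,)
    ap:      array of shape (nc, 10); column 0 is AP@0.5
    names:   dict {class_index: class_name}
    exclude: set of class names to drop
    Returns filtered (py, ap, names), with names re-indexed from 0.
    """
    keep = [i for i in range(len(py)) if names[i] not in exclude]
    py_kept = [py[i] for i in keep]
    ap_kept = np.asarray(ap)[keep]
    names_kept = {j: names[i] for j, i in enumerate(keep)}
    return py_kept, ap_kept, names_kept
```

For example, `py9, ap9, names9 = drop_classes(py, ap, names, {'train'})` followed by `plot_pr_curve(px, py9, ap9, save_dir, names9)` would draw the 9-class plot; using new variable names keeps the values that `ap_per_class` returns unchanged.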
